mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): wip node editor
This commit is contained in:
@@ -15,6 +15,7 @@ import postprocessingReducer from 'features/parameters/store/postprocessingSlice
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import modelsReducer from 'features/system/store/modelSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
|
||||
import { socketioMiddleware } from './socketio/middleware';
|
||||
import { socketMiddleware } from 'services/events/middleware';
|
||||
@@ -85,6 +86,7 @@ const rootReducer = combineReducers({
|
||||
results: resultsReducer,
|
||||
uploads: uploadsReducer,
|
||||
models: modelsReducer,
|
||||
nodes: nodesReducer,
|
||||
});
|
||||
|
||||
const rootPersistConfig = getPersistConfig({
|
||||
@@ -97,9 +99,9 @@ const rootPersistConfig = getPersistConfig({
|
||||
...galleryBlacklist,
|
||||
...lightboxBlacklist,
|
||||
...apiBlacklist,
|
||||
// for now, never persist the results/uploads slices
|
||||
'results',
|
||||
'uploads',
|
||||
// 'nodes',
|
||||
],
|
||||
debounce: 300,
|
||||
});
|
||||
|
||||
@@ -1,175 +1,160 @@
|
||||
import WorkInProgress from './WorkInProgress';
|
||||
import ReactFlow, {
|
||||
applyEdgeChanges,
|
||||
applyNodeChanges,
|
||||
Background,
|
||||
Controls,
|
||||
Edge,
|
||||
Node,
|
||||
NodeTypes,
|
||||
OnEdgesChange,
|
||||
OnNodesChange,
|
||||
} from 'reactflow';
|
||||
// import WorkInProgress from './WorkInProgress';
|
||||
// import ReactFlow, {
|
||||
// applyEdgeChanges,
|
||||
// applyNodeChanges,
|
||||
// Background,
|
||||
// Controls,
|
||||
// Edge,
|
||||
// Handle,
|
||||
// Node,
|
||||
// NodeTypes,
|
||||
// OnEdgesChange,
|
||||
// OnNodesChange,
|
||||
// Position,
|
||||
// } from 'reactflow';
|
||||
|
||||
import 'reactflow/dist/style.css';
|
||||
import {
|
||||
FunctionComponent,
|
||||
ReactNode,
|
||||
useCallback,
|
||||
useMemo,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { filter, map } from 'lodash';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Input,
|
||||
Select,
|
||||
Switch,
|
||||
Text,
|
||||
NumberInput,
|
||||
NumberInputField,
|
||||
NumberInputStepper,
|
||||
NumberIncrementStepper,
|
||||
NumberDecrementStepper,
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
// import 'reactflow/dist/style.css';
|
||||
// import {
|
||||
// Fragment,
|
||||
// FunctionComponent,
|
||||
// ReactNode,
|
||||
// useCallback,
|
||||
// useMemo,
|
||||
// useState,
|
||||
// } from 'react';
|
||||
// import { OpenAPIV3 } from 'openapi-types';
|
||||
// import { filter, map, reduce } from 'lodash';
|
||||
// import {
|
||||
// Box,
|
||||
// Flex,
|
||||
// FormControl,
|
||||
// FormLabel,
|
||||
// Input,
|
||||
// Select,
|
||||
// Switch,
|
||||
// Text,
|
||||
// NumberInput,
|
||||
// NumberInputField,
|
||||
// NumberInputStepper,
|
||||
// NumberIncrementStepper,
|
||||
// NumberDecrementStepper,
|
||||
// Tooltip,
|
||||
// chakra,
|
||||
// Badge,
|
||||
// Heading,
|
||||
// VStack,
|
||||
// HStack,
|
||||
// Menu,
|
||||
// MenuButton,
|
||||
// MenuList,
|
||||
// MenuItem,
|
||||
// MenuItemOption,
|
||||
// MenuGroup,
|
||||
// MenuOptionGroup,
|
||||
// MenuDivider,
|
||||
// IconButton,
|
||||
// } from '@chakra-ui/react';
|
||||
// import { FaPlus } from 'react-icons/fa';
|
||||
// import {
|
||||
// FIELD_NAMES as FIELD_NAMES,
|
||||
// FIELDS,
|
||||
// INVOCATION_NAMES as INVOCATION_NAMES,
|
||||
// INVOCATIONS,
|
||||
// } from 'features/nodeEditor/constants';
|
||||
|
||||
// grab the openapi schema json
|
||||
async function fetchOpenAPI(): Promise<OpenAPIV3.Document> {
|
||||
const response = await fetch(`openapi.json`);
|
||||
const jsonData = await response.json();
|
||||
return jsonData;
|
||||
}
|
||||
// console.log('invocations', INVOCATIONS);
|
||||
|
||||
// build an individual input element based on the schema
|
||||
const buildInputElement = (field: OpenAPIV3.SchemaObject): ReactNode => {
|
||||
if (field.type === 'string') {
|
||||
// `string` fields may either be a text input or an enum ie select
|
||||
if (field.enum) {
|
||||
return (
|
||||
<Select defaultValue={field.default}>
|
||||
{field.enum?.map((option) => (
|
||||
<option key={option}>{option}</option>
|
||||
))}
|
||||
</Select>
|
||||
);
|
||||
}
|
||||
// const nodeTypes = reduce(
|
||||
// INVOCATIONS,
|
||||
// (acc, val, key) => {
|
||||
// acc[key] = val.component;
|
||||
// return acc;
|
||||
// },
|
||||
// {} as NodeTypes
|
||||
// );
|
||||
|
||||
return <Input defaultValue={field.default}></Input>;
|
||||
} else if (field.type === 'boolean') {
|
||||
return <Switch defaultValue={field.default}></Switch>;
|
||||
} else if (['integer', 'number'].includes(field.type as string)) {
|
||||
return (
|
||||
<NumberInput defaultValue={field.default}>
|
||||
<NumberInputField />
|
||||
<NumberInputStepper>
|
||||
<NumberIncrementStepper />
|
||||
<NumberDecrementStepper />
|
||||
</NumberInputStepper>
|
||||
</NumberInput>
|
||||
);
|
||||
}
|
||||
};
|
||||
// console.log('nodeTypes', nodeTypes);
|
||||
|
||||
// build object of invocations UI components keyed by their name
|
||||
const buildInvocations = async () => {
|
||||
// get schema
|
||||
const openApi = await fetchOpenAPI();
|
||||
// // make initial nodes one of every node for now
|
||||
// let n = 0;
|
||||
// const initialNodes = map(INVOCATIONS, (i) => ({
|
||||
// id: i.type,
|
||||
// type: i.title,
|
||||
// position: { x: (n += 20), y: (n += 20) },
|
||||
// data: {},
|
||||
// }));
|
||||
|
||||
// filter out non-invocation schemas, kinda janky but dunno if there's another way really
|
||||
// outputs an array
|
||||
const filteredSchemas = filter(
|
||||
openApi.components?.schemas,
|
||||
(_schema, key) =>
|
||||
key.includes('Invocation') && !key.includes('InvocationOutput')
|
||||
);
|
||||
// console.log('initialNodes', initialNodes);
|
||||
|
||||
// actually build the UI components
|
||||
// reduce the array of schemas into an object of react function components, keyed by name (eg NodeTypes)
|
||||
const invocations = filteredSchemas.reduce<NodeTypes>((acc, val, key) => {
|
||||
// we know these are always SchemaObjects not ReferenceObjects
|
||||
const schema = val as OpenAPIV3.SchemaObject;
|
||||
// export default function NodesWIP() {
|
||||
// const [nodes, setNodes] = useState<Node[]>([]);
|
||||
// const [edges, setEdges] = useState<Edge[]>([]);
|
||||
|
||||
// we know `title` will always be present
|
||||
const name = schema.title!.replace('Invocation', '');
|
||||
// const onNodesChange: OnNodesChange = useCallback(
|
||||
// (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
|
||||
// []
|
||||
// );
|
||||
|
||||
// `type` and `id` are not valid inputs/outputs
|
||||
const fields = filter(
|
||||
schema.properties,
|
||||
(prop, key) => !['type', 'id'].includes(key)
|
||||
);
|
||||
// const onEdgesChange: OnEdgesChange = useCallback(
|
||||
// (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
|
||||
// []
|
||||
// );
|
||||
|
||||
// assemble!
|
||||
acc[name] = () => (
|
||||
<Box sx={{ padding: 4, bg: 'base.800', borderRadius: 'md' }}>
|
||||
<Flex flexDirection="column">
|
||||
<Text>{name}</Text>
|
||||
{fields.map((field, i) => {
|
||||
const f = field as OpenAPIV3.SchemaObject;
|
||||
return (
|
||||
<Tooltip key={i} label={f.description} placement="top" hasArrow>
|
||||
<FormControl>
|
||||
<FormLabel>{f.title}</FormLabel>
|
||||
{buildInputElement(f)}
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
// return (
|
||||
// <Box
|
||||
// sx={{
|
||||
// position: 'relative',
|
||||
// width: 'full',
|
||||
// height: 'full',
|
||||
// borderRadius: 'md',
|
||||
// }}
|
||||
// >
|
||||
// <ReactFlow
|
||||
// nodeTypes={nodeTypes}
|
||||
// nodes={nodes}
|
||||
// edges={edges}
|
||||
// onNodesChange={onNodesChange}
|
||||
// onEdgesChange={onEdgesChange}
|
||||
// >
|
||||
// <Background />
|
||||
// <Controls />
|
||||
// </ReactFlow>
|
||||
// <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
|
||||
// {FIELD_NAMES.map((field) => (
|
||||
// <Badge
|
||||
// key={field}
|
||||
// colorScheme={FIELDS[field].color}
|
||||
// sx={{ userSelect: 'none' }}
|
||||
// >
|
||||
// {field}
|
||||
// </Badge>
|
||||
// ))}
|
||||
// </HStack>
|
||||
// <Menu>
|
||||
// <MenuButton
|
||||
// as={IconButton}
|
||||
// aria-label="Options"
|
||||
// icon={<FaPlus />}
|
||||
// sx={{ position: 'absolute', top: 2, left: 2 }}
|
||||
// />
|
||||
// <MenuList>
|
||||
// {INVOCATION_NAMES.map((name) => {
|
||||
// const invocation = INVOCATIONS[name];
|
||||
// return (
|
||||
// <Tooltip
|
||||
// key={name}
|
||||
// label={invocation.description}
|
||||
// placement="end"
|
||||
// hasArrow
|
||||
// >
|
||||
// <MenuItem>{invocation.title}</MenuItem>
|
||||
// </Tooltip>
|
||||
// );
|
||||
// })}
|
||||
// </MenuList>
|
||||
// </Menu>
|
||||
// </Box>
|
||||
// );
|
||||
// }
|
||||
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
console.log(invocations);
|
||||
|
||||
return invocations;
|
||||
};
|
||||
|
||||
const invocations = await buildInvocations();
|
||||
|
||||
// make initial nodes one of every node for now
|
||||
let n = 0;
|
||||
const initialNodes = map(
|
||||
invocations,
|
||||
(i, type): Node => ({
|
||||
id: type,
|
||||
type,
|
||||
position: { x: (n += 20), y: (n += 20) },
|
||||
data: {},
|
||||
})
|
||||
);
|
||||
|
||||
export default function NodesWIP() {
|
||||
const nodeTypes = useMemo(() => invocations, []);
|
||||
const [nodes, setNodes] = useState<Node[]>(initialNodes);
|
||||
const [edges, setEdges] = useState<Edge[]>([]);
|
||||
|
||||
const onNodesChange: OnNodesChange = useCallback(
|
||||
(changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
|
||||
[]
|
||||
);
|
||||
const onEdgesChange: OnEdgesChange = useCallback(
|
||||
(changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
|
||||
[]
|
||||
);
|
||||
return (
|
||||
<WorkInProgress>
|
||||
<ReactFlow
|
||||
nodeTypes={nodeTypes}
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
>
|
||||
<Background />
|
||||
<Controls />
|
||||
</ReactFlow>
|
||||
</WorkInProgress>
|
||||
);
|
||||
}
|
||||
export default {};
|
||||
|
||||
@@ -15,6 +15,7 @@ const WorkInProgress = (props: WorkInProgressProps) => {
|
||||
height: '100%',
|
||||
bg: 'base.850',
|
||||
borderRadius: 'base',
|
||||
position: 'relative',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
|
||||
43
invokeai/frontend/web/src/features/nodes/components/Flow.tsx
Normal file
43
invokeai/frontend/web/src/features/nodes/components/Flow.tsx
Normal file
@@ -0,0 +1,43 @@
|
||||
import {
|
||||
Background,
|
||||
Controls,
|
||||
MiniMap,
|
||||
OnEdgesChange,
|
||||
OnNodesChange,
|
||||
ReactFlow,
|
||||
} from 'reactflow';
|
||||
import { NODE_TYPES } from '../constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import { edgesChanged, nodesChanged } from '../store/nodesSlice';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
export const Flow = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
|
||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
||||
|
||||
const onNodesChange: OnNodesChange = useCallback(
|
||||
(changes) => dispatch(nodesChanged(changes)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgesChange: OnEdgesChange = useCallback(
|
||||
(changes) => dispatch(edgesChanged(changes)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
nodeTypes={NODE_TYPES}
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
>
|
||||
<Background />
|
||||
<Controls />
|
||||
<MiniMap nodeStrokeWidth={3} zoomable pannable />
|
||||
</ReactFlow>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,95 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import 'reactflow/dist/style.css';
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
Box,
|
||||
Tooltip,
|
||||
Badge,
|
||||
HStack,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import {
|
||||
FIELDS,
|
||||
FIELD_NAMES,
|
||||
INVOCATIONS,
|
||||
INVOCATION_NAMES,
|
||||
} from '../constants';
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import { nodeAdded } from '../store/nodesSlice';
|
||||
import { Flow } from './Flow';
|
||||
|
||||
const NodeEditor = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string) => {
|
||||
dispatch(
|
||||
nodeAdded({
|
||||
id: uuidv4(),
|
||||
type: nodeType,
|
||||
position: { x: 0, y: 0 },
|
||||
data: {},
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'relative',
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
borderRadius: 'md',
|
||||
bg: 'base.850',
|
||||
}}
|
||||
>
|
||||
<Flow />
|
||||
<HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
|
||||
{FIELD_NAMES.map((field) => (
|
||||
<Badge
|
||||
key={field}
|
||||
colorScheme={FIELDS[field].color}
|
||||
sx={{ userSelect: 'none' }}
|
||||
>
|
||||
{field}
|
||||
</Badge>
|
||||
))}
|
||||
</HStack>
|
||||
<Menu>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
aria-label="Options"
|
||||
icon={<FaPlus />}
|
||||
sx={{ position: 'absolute', top: 2, left: 2 }}
|
||||
/>
|
||||
<MenuList>
|
||||
{INVOCATION_NAMES.map((name) => {
|
||||
const invocation = INVOCATIONS[name];
|
||||
return (
|
||||
<Tooltip
|
||||
key={name}
|
||||
label={invocation.description}
|
||||
placement="end"
|
||||
hasArrow
|
||||
>
|
||||
<MenuItem onClick={() => addNode(invocation.title)}>
|
||||
{invocation.title}
|
||||
</MenuItem>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default NodeEditor;
|
||||
83
invokeai/frontend/web/src/features/nodes/constants.ts
Normal file
83
invokeai/frontend/web/src/features/nodes/constants.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
import { map, reduce } from 'lodash';
|
||||
import { NodeTypes } from 'reactflow';
|
||||
import { FieldConfig } from './types';
|
||||
import { buildInvocations } from './util/buildInvocations';
|
||||
|
||||
// here we fetch the schema, parse it and output all the constants
|
||||
|
||||
export const PRIMITIVE_FIELDS = [
|
||||
'integer',
|
||||
'number',
|
||||
'boolean',
|
||||
'string',
|
||||
'object',
|
||||
'array',
|
||||
];
|
||||
|
||||
export const AVAILABLE_COLORS = [
|
||||
'red',
|
||||
'orange',
|
||||
'yellow',
|
||||
'green',
|
||||
'teal',
|
||||
'blue',
|
||||
'cyan',
|
||||
'purple',
|
||||
'pink',
|
||||
];
|
||||
|
||||
export const { invocations: INVOCATIONS, customFields } =
|
||||
await buildInvocations();
|
||||
|
||||
export const INVOCATION_NAMES: (keyof typeof INVOCATIONS)[] = map(
|
||||
INVOCATIONS,
|
||||
(_, key) => key
|
||||
);
|
||||
|
||||
export const NODE_TYPES = reduce(
|
||||
INVOCATIONS,
|
||||
(acc, val, key) => {
|
||||
acc[val.title] = val.component;
|
||||
return acc;
|
||||
},
|
||||
{} as NodeTypes
|
||||
);
|
||||
|
||||
export const NODE_TYPE_NAMES: (keyof typeof NODE_TYPES)[] = map(
|
||||
NODE_TYPES,
|
||||
(_, key) => key
|
||||
);
|
||||
|
||||
console.log(customFields);
|
||||
|
||||
// all field types, maybe we can dynamically generate this in the future?
|
||||
export const FIELDS = [
|
||||
...PRIMITIVE_FIELDS,
|
||||
...customFields,
|
||||
].reduce<FieldConfig>((acc, val, i) => {
|
||||
let color = AVAILABLE_COLORS[i];
|
||||
if (!color) {
|
||||
color = 'gray';
|
||||
console.log('RAN OUTTA COLORS YO');
|
||||
}
|
||||
|
||||
acc[val] = {
|
||||
color,
|
||||
isPrimitive: PRIMITIVE_FIELDS.includes(val),
|
||||
};
|
||||
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
// helper array of all field names
|
||||
export const FIELD_NAMES: (keyof typeof FIELDS)[] = map(
|
||||
FIELDS,
|
||||
(_, key) => key
|
||||
);
|
||||
|
||||
console.log('INVOCATIONS', INVOCATIONS);
|
||||
console.log('INVOCATION_NAMES', INVOCATION_NAMES);
|
||||
console.log('FIELDS', FIELDS);
|
||||
console.log('FIELD_NAMES', FIELD_NAMES);
|
||||
console.log('NODE_TYPES', NODE_TYPES);
|
||||
console.log('NODE_TYPE_NAMES', NODE_TYPE_NAMES);
|
||||
44
invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
Normal file
44
invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import {
|
||||
addEdge,
|
||||
applyEdgeChanges,
|
||||
applyNodeChanges,
|
||||
Connection,
|
||||
Edge,
|
||||
EdgeChange,
|
||||
Node,
|
||||
NodeChange,
|
||||
} from 'reactflow';
|
||||
|
||||
export type NodesState = {
|
||||
nodes: Node[];
|
||||
edges: Edge[];
|
||||
};
|
||||
|
||||
export const initialNodesState: NodesState = {
|
||||
nodes: [],
|
||||
edges: [],
|
||||
};
|
||||
|
||||
const nodesSlice = createSlice({
|
||||
name: 'results',
|
||||
initialState: initialNodesState,
|
||||
reducers: {
|
||||
nodesChanged: (state, action: PayloadAction<NodeChange[]>) => {
|
||||
state.nodes = applyNodeChanges(action.payload, state.nodes);
|
||||
},
|
||||
nodeAdded: (state, action: PayloadAction<Node>) => {
|
||||
state.nodes.push(action.payload);
|
||||
},
|
||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
state.edges = addEdge(action.payload, state.edges);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { nodesChanged, edgesChanged, nodeAdded } = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
||||
83
invokeai/frontend/web/src/features/nodes/types.ts
Normal file
83
invokeai/frontend/web/src/features/nodes/types.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { FunctionComponent } from 'react';
|
||||
|
||||
export const isReferenceObject = (
|
||||
obj:
|
||||
| OpenAPIV3.ReferenceObject
|
||||
| OpenAPIV3.SchemaObject
|
||||
| NodeSchemaObject
|
||||
| ProcessedNodeSchemaObject
|
||||
): obj is OpenAPIV3.ReferenceObject => '$ref' in obj;
|
||||
|
||||
export const isNodeSchemaObject = (
|
||||
obj:
|
||||
| OpenAPIV3.ReferenceObject
|
||||
| OpenAPIV3.SchemaObject
|
||||
| NodeSchemaObject
|
||||
| ProcessedNodeSchemaObject
|
||||
): obj is NodeSchemaObject => !('$ref' in obj);
|
||||
|
||||
export const isArraySchemaObject = (
|
||||
obj: OpenAPIV3.ArraySchemaObject | OpenAPIV3.NonArraySchemaObject
|
||||
): obj is OpenAPIV3.ArraySchemaObject => 'items' in obj;
|
||||
|
||||
export const isNonArraySchemaObject = (
|
||||
obj: OpenAPIV3.ArraySchemaObject | OpenAPIV3.NonArraySchemaObject
|
||||
): obj is OpenAPIV3.NonArraySchemaObject => !('items' in obj);
|
||||
|
||||
// helper types - we have some guarantees about the schema - so we can override some optional
|
||||
// properties
|
||||
|
||||
export type RequiredInvocationProperties = {
|
||||
type: string;
|
||||
title: string;
|
||||
id: string;
|
||||
output: OpenAPIV3.ReferenceObject; // add the `output` custom schema prop
|
||||
properties: OpenAPIV3.ReferenceObject | NodeSchemaObject;
|
||||
};
|
||||
|
||||
export type NodeSchemaObject = Omit<
|
||||
OpenAPIV3.SchemaObject,
|
||||
keyof RequiredInvocationProperties
|
||||
> &
|
||||
RequiredInvocationProperties;
|
||||
|
||||
export type ProcessedNodeSchemaObject = NodeSchemaObject & {
|
||||
fieldType: string;
|
||||
};
|
||||
|
||||
export type NodesComponentsObject = Omit<
|
||||
OpenAPIV3.ComponentsObject,
|
||||
'schemas'
|
||||
> & {
|
||||
// we know we always have schemas
|
||||
schemas: {
|
||||
[key: string]:
|
||||
| OpenAPIV3.ReferenceObject
|
||||
| (NodeSchemaObject & { properties: { type: { default: string } } });
|
||||
};
|
||||
};
|
||||
|
||||
export type NodesOpenAPIDocument = Omit<OpenAPIV3.Document, 'components'> & {
|
||||
// we know we always have components
|
||||
components: NodesComponentsObject;
|
||||
};
|
||||
|
||||
export type Invocation = {
|
||||
title: string;
|
||||
type: string;
|
||||
description: string;
|
||||
schema: NodeSchemaObject;
|
||||
outputs: ProcessedNodeSchemaObject[];
|
||||
inputs: ProcessedNodeSchemaObject[];
|
||||
component: FunctionComponent;
|
||||
};
|
||||
|
||||
export type Invocations = { [name: string]: Invocation };
|
||||
|
||||
export type FieldConfig = {
|
||||
[type: string]: {
|
||||
color: string;
|
||||
isPrimitive: boolean;
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,44 @@
|
||||
import {
|
||||
Input,
|
||||
NumberDecrementStepper,
|
||||
NumberIncrementStepper,
|
||||
NumberInput,
|
||||
NumberInputField,
|
||||
NumberInputStepper,
|
||||
Select,
|
||||
Switch,
|
||||
} from '@chakra-ui/react';
|
||||
import { ReactNode } from 'react';
|
||||
import { ProcessedNodeSchemaObject } from '../types';
|
||||
|
||||
// build an individual input element based on the schema
|
||||
export const buildFieldComponent = (
|
||||
field: ProcessedNodeSchemaObject
|
||||
): ReactNode => {
|
||||
if (field.fieldType === 'string') {
|
||||
// `string` fields may either be a text input or an enum ie select
|
||||
if (field.enum) {
|
||||
return (
|
||||
<Select defaultValue={field.default}>
|
||||
{field.enum?.map((option) => (
|
||||
<option key={option}>{option}</option>
|
||||
))}
|
||||
</Select>
|
||||
);
|
||||
}
|
||||
|
||||
return <Input defaultValue={field.default}></Input>;
|
||||
} else if (field.fieldType === 'boolean') {
|
||||
return <Switch defaultValue={field.default}></Switch>;
|
||||
} else if (['integer', 'number'].includes(field.fieldType as string)) {
|
||||
return (
|
||||
<NumberInput defaultValue={field.default}>
|
||||
<NumberInputField />
|
||||
<NumberInputStepper>
|
||||
<NumberIncrementStepper />
|
||||
<NumberDecrementStepper />
|
||||
</NumberInputStepper>
|
||||
</NumberInput>
|
||||
);
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,55 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { Handle, Position } from 'reactflow';
|
||||
import { FIELDS } from '../constants';
|
||||
import { ProcessedNodeSchemaObject } from '../types';
|
||||
|
||||
export const buildInputHandleComponent = (field: ProcessedNodeSchemaObject) => {
|
||||
const color =
|
||||
field.fieldType in FIELDS ? FIELDS[field.fieldType].color : 'gray';
|
||||
return (
|
||||
<Tooltip
|
||||
key={field.title}
|
||||
label={field.fieldType}
|
||||
placement="start"
|
||||
hasArrow
|
||||
>
|
||||
<Handle
|
||||
type="target"
|
||||
id={field.title}
|
||||
position={Position.Left}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
left: '-1.5rem',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
backgroundColor: `var(--invokeai-colors-${color}-500)`,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export const buildOutputHandleComponent = (
|
||||
field: ProcessedNodeSchemaObject,
|
||||
top: string
|
||||
) => {
|
||||
const color =
|
||||
field.fieldType in FIELDS ? FIELDS[field.fieldType].color : 'gray';
|
||||
return (
|
||||
<Tooltip key={field.title} label={field.fieldType} placement="end" hasArrow>
|
||||
<Handle
|
||||
type="target"
|
||||
id={field.title}
|
||||
position={Position.Right}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top,
|
||||
right: '-0.5rem',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
backgroundColor: `var(--invokeai-colors-${color}-500)`,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,174 @@
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Heading,
|
||||
HStack,
|
||||
Tooltip,
|
||||
Icon,
|
||||
} from '@chakra-ui/react';
|
||||
import { filter } from 'lodash';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
import { PRIMITIVE_FIELDS } from '../constants';
|
||||
import {
|
||||
Invocations,
|
||||
isNodeSchemaObject,
|
||||
isReferenceObject,
|
||||
NodesOpenAPIDocument,
|
||||
ProcessedNodeSchemaObject,
|
||||
} from '../types';
|
||||
import { buildFieldComponent } from './buildFieldComponent';
|
||||
import {
|
||||
buildInputHandleComponent,
|
||||
buildOutputHandleComponent,
|
||||
} from './buildHandleComponent';
|
||||
import { fetchOpenAPISchema } from './fetchOpenAPISchema';
|
||||
import { parseOutputRef } from './parseRef';
|
||||
|
||||
// build object of invocations UI components keyed by their name
|
||||
export const buildInvocations = async (): Promise<{
|
||||
invocations: Invocations;
|
||||
customFields: string[];
|
||||
}> => {
|
||||
// get schema - cast as the modified OpenAPI document type
|
||||
const openApi = (await fetchOpenAPISchema()) as NodesOpenAPIDocument;
|
||||
|
||||
// filter out non-invocation schemas, kinda janky but dunno if there's another way really
|
||||
// also filter out some tricky ones for now
|
||||
const filteredSchemas = filter(
|
||||
openApi.components.schemas,
|
||||
(_schema, key) =>
|
||||
key.includes('Invocation') &&
|
||||
!key.includes('InvocationOutput') &&
|
||||
!key.includes('Collect') &&
|
||||
!key.includes('Range') &&
|
||||
!key.includes('Iterate') &&
|
||||
!key.includes('Graph')
|
||||
);
|
||||
|
||||
const customFields: string[] = [];
|
||||
|
||||
// actually build the UI components
|
||||
// reduce the array of schemas into an object of react function components, keyed by name (eg NodeTypes)
|
||||
const invocations = filteredSchemas.reduce<Invocations>(
|
||||
(acc, schema, key) => {
|
||||
// only want SchemaObjects
|
||||
if (isReferenceObject(schema)) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
const title = schema.title.replace('Invocation', '');
|
||||
|
||||
const type = schema.properties.type.default;
|
||||
|
||||
// `type` and `id` are not valid inputs/outputs
|
||||
const inputs = filter(
|
||||
schema.properties,
|
||||
(prop, key) => !['type', 'id'].includes(key) && isNodeSchemaObject(prop)
|
||||
) as ProcessedNodeSchemaObject[]; // if i don't cast as, the type is never[], dunno why
|
||||
|
||||
inputs.forEach((input) => {
|
||||
if (input.allOf && isReferenceObject(input.allOf[0])) {
|
||||
input.fieldType = input.allOf[0].$ref.split('/').slice(-1)[0];
|
||||
} else {
|
||||
input.fieldType = input.type;
|
||||
}
|
||||
|
||||
if (
|
||||
!customFields.includes(input.fieldType) &&
|
||||
!PRIMITIVE_FIELDS.includes(input.fieldType)
|
||||
) {
|
||||
customFields.push(input.fieldType);
|
||||
}
|
||||
});
|
||||
|
||||
const outputs = [parseOutputRef(openApi.components, schema.output.$ref)];
|
||||
outputs.forEach(({ fieldType }) => {
|
||||
if (
|
||||
!customFields.includes(fieldType) &&
|
||||
!PRIMITIVE_FIELDS.includes(fieldType)
|
||||
) {
|
||||
customFields.push(fieldType);
|
||||
}
|
||||
});
|
||||
|
||||
// assemble!
|
||||
acc[title] = {
|
||||
title,
|
||||
type,
|
||||
schema,
|
||||
outputs,
|
||||
inputs,
|
||||
description: schema.description || '',
|
||||
component: () => (
|
||||
<Box
|
||||
sx={{
|
||||
padding: 4,
|
||||
bg: 'base.800',
|
||||
borderRadius: 'md',
|
||||
boxShadow: 'dark-lg',
|
||||
}}
|
||||
>
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
<HStack justifyContent="space-between">
|
||||
<Heading size="sm" fontWeight={500} color="base.100">
|
||||
{title}
|
||||
</Heading>
|
||||
<Tooltip
|
||||
label={schema.description}
|
||||
placement="top"
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
>
|
||||
<Icon color="base.300" as={FaInfoCircle} />
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
{inputs.map((input, i) => {
|
||||
if (isNodeSchemaObject(input)) {
|
||||
return (
|
||||
<Box
|
||||
key={i}
|
||||
position="relative"
|
||||
p={2}
|
||||
borderWidth={1}
|
||||
borderRadius="md"
|
||||
>
|
||||
<FormControl>
|
||||
<HStack
|
||||
justifyContent="space-between"
|
||||
alignItems="center"
|
||||
>
|
||||
<FormLabel>{input.title}</FormLabel>
|
||||
<Tooltip
|
||||
label={input.description}
|
||||
placement="top"
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
>
|
||||
<Icon color="base.400" as={FaInfoCircle} />
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
{buildFieldComponent(input)}
|
||||
</FormControl>
|
||||
{buildInputHandleComponent(input)}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
})}
|
||||
</Flex>
|
||||
{outputs.map((output, i) => {
|
||||
const top = `${(100 / (outputs.length + 1)) * (i + 1)}%`;
|
||||
return buildOutputHandleComponent(output, top);
|
||||
})}
|
||||
</Box>
|
||||
),
|
||||
};
|
||||
|
||||
return acc;
|
||||
},
|
||||
{}
|
||||
);
|
||||
|
||||
return { invocations, customFields };
|
||||
};
|
||||
@@ -0,0 +1,8 @@
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
|
||||
// grab the openapi schema json
|
||||
export async function fetchOpenAPISchema(): Promise<OpenAPIV3.Document> {
|
||||
const response = await fetch(`openapi.json`);
|
||||
const jsonData = await response.json();
|
||||
return jsonData;
|
||||
}
|
||||
38
invokeai/frontend/web/src/features/nodes/util/parseRef.ts
Normal file
38
invokeai/frontend/web/src/features/nodes/util/parseRef.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { filter } from 'lodash';
|
||||
import {
|
||||
isReferenceObject,
|
||||
NodeSchemaObject,
|
||||
NodesComponentsObject,
|
||||
ProcessedNodeSchemaObject,
|
||||
} from '../types';
|
||||
|
||||
export const parseOutputRef = (
|
||||
components: NodesComponentsObject,
|
||||
ref: string
|
||||
) => {
|
||||
// extract output schema name from ref
|
||||
const outputSchemaName = ref.split('/').slice(-1)[0].toString();
|
||||
|
||||
// TODO: recursively parse refs? currently just manually going one level deep
|
||||
const output = components.schemas[
|
||||
outputSchemaName
|
||||
] as unknown as ProcessedNodeSchemaObject;
|
||||
|
||||
const filteredProperties = filter(
|
||||
output.properties,
|
||||
(prop, key) => key !== 'type'
|
||||
) as NodeSchemaObject[];
|
||||
|
||||
if (filteredProperties[0]?.allOf?.length) {
|
||||
if (isReferenceObject(filteredProperties[0].allOf[0])) {
|
||||
output.fieldType = filteredProperties[0].allOf[0].$ref
|
||||
.split('/')
|
||||
.slice(-1)[0]
|
||||
.toString();
|
||||
}
|
||||
} else {
|
||||
output.fieldType = filteredProperties[0].type;
|
||||
}
|
||||
|
||||
return output;
|
||||
};
|
||||
@@ -34,6 +34,7 @@ import UnifiedCanvasWorkarea from 'features/ui/components/tabs/UnifiedCanvas/Uni
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import NodeEditor from 'features/nodes/components/NodeEditor';
|
||||
|
||||
export interface InvokeTabInfo {
|
||||
id: InvokeTabName;
|
||||
@@ -65,7 +66,7 @@ const buildTabs = (disabledTabs: InvokeTabName[]): InvokeTabInfo[] => {
|
||||
{
|
||||
id: 'nodes',
|
||||
icon: <Icon as={MdDeviceHub} sx={tabIconStyles} />,
|
||||
workarea: <NodesWIP />,
|
||||
workarea: <NodeEditor />,
|
||||
},
|
||||
{
|
||||
id: 'postprocessing',
|
||||
|
||||
Reference in New Issue
Block a user