mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
fix(ui): add basic node edges & connection validation
This commit is contained in:
7
invokeai/frontend/web/docs/NODE_EDITOR.md
Normal file
7
invokeai/frontend/web/docs/NODE_EDITOR.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# Node Editor Design
|
||||
|
||||
WIP
|
||||
|
||||
- on socket connect, if no schema loaded, fetch `localhost:9090/openapi.json`, saved to `state.nodes.schema`
|
||||
- on fulfilled fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations`
|
||||
- when you add a node, it gives it to `InvocationComponent.tsx`
|
||||
@@ -2,13 +2,20 @@ import {
|
||||
Background,
|
||||
Controls,
|
||||
MiniMap,
|
||||
OnConnect,
|
||||
OnEdgesChange,
|
||||
OnNodesChange,
|
||||
ReactFlow,
|
||||
ConnectionLineType,
|
||||
OnConnectStart,
|
||||
} from 'reactflow';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import { edgesChanged, nodesChanged } from '../store/nodesSlice';
|
||||
import {
|
||||
connectionMade,
|
||||
edgesChanged,
|
||||
nodesChanged,
|
||||
} from '../store/nodesSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { InvocationComponent } from './InvocationComponent';
|
||||
|
||||
@@ -20,15 +27,31 @@ export const Flow = () => {
|
||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
||||
|
||||
const onNodesChange: OnNodesChange = useCallback(
|
||||
(changes) => dispatch(nodesChanged(changes)),
|
||||
(changes) => {
|
||||
dispatch(nodesChanged(changes));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgesChange: OnEdgesChange = useCallback(
|
||||
(changes) => dispatch(edgesChanged(changes)),
|
||||
(changes) => {
|
||||
dispatch(edgesChanged(changes));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnect: OnConnect = useCallback(
|
||||
(changes) => {
|
||||
console.log('connect');
|
||||
dispatch(connectionMade(changes));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnectStart: OnConnectStart = useCallback((changes) => {
|
||||
console.log('connect start');
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
nodeTypes={nodeTypes}
|
||||
@@ -36,6 +59,10 @@ export const Flow = () => {
|
||||
edges={edges}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onConnect={onConnect}
|
||||
onConnectStart={onConnectStart}
|
||||
connectionLineType={ConnectionLineType.SmoothStep}
|
||||
defaultEdgeOptions={{ type: 'smoothstep' }}
|
||||
>
|
||||
<Background />
|
||||
<Controls />
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { InputField } from '../types';
|
||||
import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent';
|
||||
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
|
||||
@@ -39,4 +40,6 @@ export const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
if (type === 'latents') {
|
||||
return <LatentsInputFieldComponent nodeId={nodeId} field={field} />;
|
||||
}
|
||||
|
||||
return <Box p={2}>Unknown field type: {type}</Box>;
|
||||
};
|
||||
|
||||
@@ -1,20 +1,29 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { Handle, Position } from 'reactflow';
|
||||
import { FIELDS, InputField } from '../types';
|
||||
import { Handle, Position, Connection } from 'reactflow';
|
||||
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
|
||||
import { InputField } from '../types';
|
||||
|
||||
type InputHandleProps = {
|
||||
nodeId: string;
|
||||
field: InputField;
|
||||
isValidConnection: (connection: Connection) => boolean;
|
||||
};
|
||||
|
||||
export const InputHandle = (props: InputHandleProps) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeId, field, isValidConnection } = props;
|
||||
const { name, title, type, description } = field;
|
||||
return (
|
||||
<Tooltip key={name} label={`${title} (${type})`} placement="start" hasArrow>
|
||||
<Tooltip
|
||||
key={name}
|
||||
label={`${title} (${type})`}
|
||||
placement="start"
|
||||
hasArrow
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<Handle
|
||||
type="target"
|
||||
id={name}
|
||||
isValidConnection={isValidConnection}
|
||||
position={Position.Left}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { NodeProps } from 'reactflow';
|
||||
import { Connection, Edge, NodeProps, useReactFlow } from 'reactflow';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
@@ -15,11 +15,57 @@ import { InputFieldComponent } from './InputFieldComponent';
|
||||
import { OutputHandle } from './OutputHandle';
|
||||
import { InputHandle } from './InputHandle';
|
||||
import { map, size } from 'lodash';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
export const InvocationComponent = (props: NodeProps<Invocation>) => {
|
||||
export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
|
||||
const { id, data, selected } = props;
|
||||
const { type, title, description, inputs, outputs } = data;
|
||||
|
||||
const flow = useReactFlow();
|
||||
|
||||
// Check if an in-progress connection is valid
|
||||
const isValidConnection = useCallback(
|
||||
(connection: Connection): boolean => {
|
||||
const edges = flow.getEdges();
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return (
|
||||
edge.target === connection.target &&
|
||||
edge.targetHandle === connection.targetHandle
|
||||
);
|
||||
})
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Find the source and target nodes...
|
||||
if (connection.source && connection.target) {
|
||||
const sourceNode = flow.getNode(connection.source);
|
||||
const targetNode = flow.getNode(connection.target);
|
||||
|
||||
// Conditional guards against undefined nodes/handles
|
||||
if (
|
||||
sourceNode &&
|
||||
targetNode &&
|
||||
connection.sourceHandle &&
|
||||
connection.targetHandle
|
||||
) {
|
||||
// connection types must be the same for a connection
|
||||
return (
|
||||
sourceNode.data.outputs[connection.sourceHandle].type ===
|
||||
targetNode.data.inputs[connection.targetHandle].type
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Default to invalid
|
||||
return false;
|
||||
},
|
||||
[flow]
|
||||
);
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
@@ -65,9 +111,13 @@ export const InvocationComponent = (props: NodeProps<Invocation>) => {
|
||||
<Icon color="base.400" as={FaInfoCircle} />
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
{InputFieldComponent({ nodeId: id, field: input })}
|
||||
<InputFieldComponent nodeId={id} field={input} />
|
||||
</FormControl>
|
||||
{InputHandle({ nodeId: id, field: input })}
|
||||
<InputHandle
|
||||
nodeId={id}
|
||||
field={input}
|
||||
isValidConnection={isValidConnection}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
})}
|
||||
@@ -75,8 +125,18 @@ export const InvocationComponent = (props: NodeProps<Invocation>) => {
|
||||
</Flex>
|
||||
{map(outputs).map((output, i) => {
|
||||
const top = `${(100 / (size(outputs) + 1)) * (i + 1)}%`;
|
||||
return OutputHandle({ nodeId: id, field: output, top });
|
||||
return (
|
||||
<OutputHandle
|
||||
key={output.name}
|
||||
nodeId={id}
|
||||
field={output}
|
||||
isValidConnection={isValidConnection}
|
||||
top={top}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
InvocationComponent.displayName = 'InvocationComponent';
|
||||
|
||||
@@ -17,9 +17,9 @@ import { FaPlus } from 'react-icons/fa';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { nodeAdded } from '../store/nodesSlice';
|
||||
import { Flow } from './Flow';
|
||||
import { FIELDS } from '../types';
|
||||
import { map } from 'lodash';
|
||||
import { RootState } from 'app/store';
|
||||
import { FIELDS } from '../constants';
|
||||
|
||||
const NodeEditor = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
@@ -1,21 +1,30 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { Handle, Position } from 'reactflow';
|
||||
import { FIELDS, OutputField } from '../types';
|
||||
import { Handle, Position, Connection } from 'reactflow';
|
||||
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
|
||||
import { OutputField } from '../types';
|
||||
|
||||
type OutputHandleProps = {
|
||||
nodeId: string;
|
||||
field: OutputField;
|
||||
isValidConnection: (connection: Connection) => boolean;
|
||||
top: string;
|
||||
};
|
||||
|
||||
export const OutputHandle = (props: OutputHandleProps) => {
|
||||
const { nodeId, field, top } = props;
|
||||
const { nodeId, field, isValidConnection, top } = props;
|
||||
const { name, title, type, description } = field;
|
||||
|
||||
return (
|
||||
<Tooltip key={name} label={`${title} (${type})`} placement="end" hasArrow>
|
||||
<Tooltip
|
||||
label={`${title} (${type})`}
|
||||
placement="end"
|
||||
hasArrow
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<Handle
|
||||
type="target"
|
||||
type="source"
|
||||
id={name}
|
||||
isValidConnection={isValidConnection}
|
||||
position={Position.Right}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
|
||||
51
invokeai/frontend/web/src/features/nodes/constants.ts
Normal file
51
invokeai/frontend/web/src/features/nodes/constants.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
import { FieldType, FieldUIConfig } from './types';
|
||||
|
||||
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
|
||||
|
||||
export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
integer: 'integer',
|
||||
number: 'float',
|
||||
string: 'string',
|
||||
boolean: 'boolean',
|
||||
enum: 'enum',
|
||||
ImageField: 'image',
|
||||
LatentsField: 'latents',
|
||||
};
|
||||
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
integer: {
|
||||
color: 'red',
|
||||
title: 'Integer',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
},
|
||||
float: {
|
||||
color: 'orange',
|
||||
title: 'Float',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
},
|
||||
string: {
|
||||
color: 'yellow',
|
||||
title: 'String',
|
||||
description: 'Strings are text.',
|
||||
},
|
||||
boolean: {
|
||||
color: 'green',
|
||||
title: 'Boolean',
|
||||
description: 'Booleans are true or false.',
|
||||
},
|
||||
enum: {
|
||||
color: 'blue',
|
||||
title: 'Enum',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
},
|
||||
image: {
|
||||
color: 'purple',
|
||||
title: 'Image',
|
||||
description: 'Images may be passed between nodes.',
|
||||
},
|
||||
latents: {
|
||||
color: 'pink',
|
||||
title: 'Latents',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
},
|
||||
};
|
||||
@@ -82,7 +82,12 @@ const nodesSlice = createSlice({
|
||||
},
|
||||
});
|
||||
|
||||
export const { nodesChanged, edgesChanged, nodeAdded, fieldValueChanged } =
|
||||
nodesSlice.actions;
|
||||
export const {
|
||||
nodesChanged,
|
||||
edgesChanged,
|
||||
nodeAdded,
|
||||
fieldValueChanged,
|
||||
connectionMade,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
||||
|
||||
@@ -33,60 +33,12 @@ export type Invocation = {
|
||||
// outputs: OutputField[];
|
||||
};
|
||||
|
||||
export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
integer: 'integer',
|
||||
number: 'float',
|
||||
string: 'string',
|
||||
boolean: 'boolean',
|
||||
enum: 'enum',
|
||||
ImageField: 'image',
|
||||
LatentsField: 'latents',
|
||||
};
|
||||
|
||||
export type FieldUIConfig = {
|
||||
color: 'red' | 'orange' | 'yellow' | 'green' | 'blue' | 'purple' | 'pink';
|
||||
title: string;
|
||||
description: string;
|
||||
};
|
||||
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
integer: {
|
||||
color: 'red',
|
||||
title: 'Integer',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
},
|
||||
float: {
|
||||
color: 'orange',
|
||||
title: 'Float',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
},
|
||||
string: {
|
||||
color: 'yellow',
|
||||
title: 'String',
|
||||
description: 'Strings are text.',
|
||||
},
|
||||
boolean: {
|
||||
color: 'green',
|
||||
title: 'Boolean',
|
||||
description: 'Booleans are true or false.',
|
||||
},
|
||||
enum: {
|
||||
color: 'blue',
|
||||
title: 'Enum',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
},
|
||||
image: {
|
||||
color: 'purple',
|
||||
title: 'Image',
|
||||
description: 'Images may be passed between nodes.',
|
||||
},
|
||||
latents: {
|
||||
color: 'pink',
|
||||
title: 'Latents',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
},
|
||||
};
|
||||
|
||||
export type FieldType =
|
||||
| 'integer'
|
||||
| 'float'
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { reduce } from 'lodash';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { FIELD_TYPE_MAP } from '../constants';
|
||||
import {
|
||||
BooleanInputField,
|
||||
EnumInputField,
|
||||
FIELD_TYPE_MAP,
|
||||
FloatInputField,
|
||||
ImageInputField,
|
||||
IntegerInputField,
|
||||
|
||||
Reference in New Issue
Block a user