fix(ui): add basic node edges & connection validation

This commit is contained in:
psychedelicious
2023-04-09 15:43:12 +10:00
parent 48677ac10b
commit f4e2928ac3
11 changed files with 193 additions and 70 deletions

View File

@@ -0,0 +1,7 @@
# Node Editor Design
WIP
- on socket connect, if no schema loaded, fetch `localhost:9090/openapi.json`, saved to `state.nodes.schema`
- on fulfilled fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations`
- when you add a node, it gives it to `InvocationComponent.tsx`

View File

@@ -2,13 +2,20 @@ import {
Background,
Controls,
MiniMap,
OnConnect,
OnEdgesChange,
OnNodesChange,
ReactFlow,
ConnectionLineType,
OnConnectStart,
} from 'reactflow';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import { edgesChanged, nodesChanged } from '../store/nodesSlice';
import {
connectionMade,
edgesChanged,
nodesChanged,
} from '../store/nodesSlice';
import { useCallback } from 'react';
import { InvocationComponent } from './InvocationComponent';
@@ -20,15 +27,31 @@ export const Flow = () => {
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const onNodesChange: OnNodesChange = useCallback(
(changes) => dispatch(nodesChanged(changes)),
(changes) => {
dispatch(nodesChanged(changes));
},
[dispatch]
);
const onEdgesChange: OnEdgesChange = useCallback(
(changes) => dispatch(edgesChanged(changes)),
(changes) => {
dispatch(edgesChanged(changes));
},
[dispatch]
);
const onConnect: OnConnect = useCallback(
(changes) => {
console.log('connect');
dispatch(connectionMade(changes));
},
[dispatch]
);
const onConnectStart: OnConnectStart = useCallback((changes) => {
console.log('connect start');
}, []);
return (
<ReactFlow
nodeTypes={nodeTypes}
@@ -36,6 +59,10 @@ export const Flow = () => {
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onConnect={onConnect}
onConnectStart={onConnectStart}
connectionLineType={ConnectionLineType.SmoothStep}
defaultEdgeOptions={{ type: 'smoothstep' }}
>
<Background />
<Controls />

View File

@@ -1,3 +1,4 @@
import { Box } from '@chakra-ui/react';
import { InputField } from '../types';
import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent';
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
@@ -39,4 +40,6 @@ export const InputFieldComponent = (props: InputFieldComponentProps) => {
if (type === 'latents') {
return <LatentsInputFieldComponent nodeId={nodeId} field={field} />;
}
return <Box p={2}>Unknown field type: {type}</Box>;
};

View File

@@ -1,20 +1,29 @@
import { Tooltip } from '@chakra-ui/react';
import { Handle, Position } from 'reactflow';
import { FIELDS, InputField } from '../types';
import { Handle, Position, Connection } from 'reactflow';
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
import { InputField } from '../types';
type InputHandleProps = {
nodeId: string;
field: InputField;
isValidConnection: (connection: Connection) => boolean;
};
export const InputHandle = (props: InputHandleProps) => {
const { nodeId, field } = props;
const { nodeId, field, isValidConnection } = props;
const { name, title, type, description } = field;
return (
<Tooltip key={name} label={`${title} (${type})`} placement="start" hasArrow>
<Tooltip
key={name}
label={`${title} (${type})`}
placement="start"
hasArrow
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Handle
type="target"
id={name}
isValidConnection={isValidConnection}
position={Position.Left}
style={{
position: 'absolute',

View File

@@ -1,4 +1,4 @@
import { NodeProps } from 'reactflow';
import { Connection, Edge, NodeProps, useReactFlow } from 'reactflow';
import {
Box,
Flex,
@@ -15,11 +15,57 @@ import { InputFieldComponent } from './InputFieldComponent';
import { OutputHandle } from './OutputHandle';
import { InputHandle } from './InputHandle';
import { map, size } from 'lodash';
import { memo, useCallback } from 'react';
export const InvocationComponent = (props: NodeProps<Invocation>) => {
export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
const { id, data, selected } = props;
const { type, title, description, inputs, outputs } = data;
const flow = useReactFlow();
// Check if an in-progress connection is valid
const isValidConnection = useCallback(
(connection: Connection): boolean => {
const edges = flow.getEdges();
// Connection is invalid if target already has a connection
if (
edges.find((edge) => {
return (
edge.target === connection.target &&
edge.targetHandle === connection.targetHandle
);
})
) {
return false;
}
// Find the source and target nodes...
if (connection.source && connection.target) {
const sourceNode = flow.getNode(connection.source);
const targetNode = flow.getNode(connection.target);
// Conditional guards against undefined nodes/handles
if (
sourceNode &&
targetNode &&
connection.sourceHandle &&
connection.targetHandle
) {
// connection types must be the same for a connection
return (
sourceNode.data.outputs[connection.sourceHandle].type ===
targetNode.data.inputs[connection.targetHandle].type
);
}
}
// Default to invalid
return false;
},
[flow]
);
return (
<Box
sx={{
@@ -65,9 +111,13 @@ export const InvocationComponent = (props: NodeProps<Invocation>) => {
<Icon color="base.400" as={FaInfoCircle} />
</Tooltip>
</HStack>
{InputFieldComponent({ nodeId: id, field: input })}
<InputFieldComponent nodeId={id} field={input} />
</FormControl>
{InputHandle({ nodeId: id, field: input })}
<InputHandle
nodeId={id}
field={input}
isValidConnection={isValidConnection}
/>
</Box>
);
})}
@@ -75,8 +125,18 @@ export const InvocationComponent = (props: NodeProps<Invocation>) => {
</Flex>
{map(outputs).map((output, i) => {
const top = `${(100 / (size(outputs) + 1)) * (i + 1)}%`;
return OutputHandle({ nodeId: id, field: output, top });
return (
<OutputHandle
key={output.name}
nodeId={id}
field={output}
isValidConnection={isValidConnection}
top={top}
/>
);
})}
</Box>
);
};
});
InvocationComponent.displayName = 'InvocationComponent';

View File

@@ -17,9 +17,9 @@ import { FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { Flow } from './Flow';
import { FIELDS } from '../types';
import { map } from 'lodash';
import { RootState } from 'app/store';
import { FIELDS } from '../constants';
const NodeEditor = () => {
const dispatch = useAppDispatch();

View File

@@ -1,21 +1,30 @@
import { Tooltip } from '@chakra-ui/react';
import { Handle, Position } from 'reactflow';
import { FIELDS, OutputField } from '../types';
import { Handle, Position, Connection } from 'reactflow';
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
import { OutputField } from '../types';
type OutputHandleProps = {
nodeId: string;
field: OutputField;
isValidConnection: (connection: Connection) => boolean;
top: string;
};
export const OutputHandle = (props: OutputHandleProps) => {
const { nodeId, field, top } = props;
const { nodeId, field, isValidConnection, top } = props;
const { name, title, type, description } = field;
return (
<Tooltip key={name} label={`${title} (${type})`} placement="end" hasArrow>
<Tooltip
label={`${title} (${type})`}
placement="end"
hasArrow
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Handle
type="target"
type="source"
id={name}
isValidConnection={isValidConnection}
position={Position.Right}
style={{
position: 'absolute',

View File

@@ -0,0 +1,51 @@
import { FieldType, FieldUIConfig } from './types';
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
export const FIELD_TYPE_MAP: Record<string, FieldType> = {
integer: 'integer',
number: 'float',
string: 'string',
boolean: 'boolean',
enum: 'enum',
ImageField: 'image',
LatentsField: 'latents',
};
export const FIELDS: Record<FieldType, FieldUIConfig> = {
integer: {
color: 'red',
title: 'Integer',
description: 'Integers are whole numbers, without a decimal point.',
},
float: {
color: 'orange',
title: 'Float',
description: 'Floats are numbers with a decimal point.',
},
string: {
color: 'yellow',
title: 'String',
description: 'Strings are text.',
},
boolean: {
color: 'green',
title: 'Boolean',
description: 'Booleans are true or false.',
},
enum: {
color: 'blue',
title: 'Enum',
description: 'Enums are values that may be one of a number of options.',
},
image: {
color: 'purple',
title: 'Image',
description: 'Images may be passed between nodes.',
},
latents: {
color: 'pink',
title: 'Latents',
description: 'Latents may be passed between nodes.',
},
};

View File

@@ -82,7 +82,12 @@ const nodesSlice = createSlice({
},
});
export const { nodesChanged, edgesChanged, nodeAdded, fieldValueChanged } =
nodesSlice.actions;
export const {
nodesChanged,
edgesChanged,
nodeAdded,
fieldValueChanged,
connectionMade,
} = nodesSlice.actions;
export default nodesSlice.reducer;

View File

@@ -33,60 +33,12 @@ export type Invocation = {
// outputs: OutputField[];
};
export const FIELD_TYPE_MAP: Record<string, FieldType> = {
integer: 'integer',
number: 'float',
string: 'string',
boolean: 'boolean',
enum: 'enum',
ImageField: 'image',
LatentsField: 'latents',
};
export type FieldUIConfig = {
color: 'red' | 'orange' | 'yellow' | 'green' | 'blue' | 'purple' | 'pink';
title: string;
description: string;
};
export const FIELDS: Record<FieldType, FieldUIConfig> = {
integer: {
color: 'red',
title: 'Integer',
description: 'Integers are whole numbers, without a decimal point.',
},
float: {
color: 'orange',
title: 'Float',
description: 'Floats are numbers with a decimal point.',
},
string: {
color: 'yellow',
title: 'String',
description: 'Strings are text.',
},
boolean: {
color: 'green',
title: 'Boolean',
description: 'Booleans are true or false.',
},
enum: {
color: 'blue',
title: 'Enum',
description: 'Enums are values that may be one of a number of options.',
},
image: {
color: 'purple',
title: 'Image',
description: 'Images may be passed between nodes.',
},
latents: {
color: 'pink',
title: 'Latents',
description: 'Latents may be passed between nodes.',
},
};
export type FieldType =
| 'integer'
| 'float'

View File

@@ -1,9 +1,9 @@
import { reduce } from 'lodash';
import { OpenAPIV3 } from 'openapi-types';
import { FIELD_TYPE_MAP } from '../constants';
import {
BooleanInputField,
EnumInputField,
FIELD_TYPE_MAP,
FloatInputField,
ImageInputField,
IntegerInputField,