mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): add connection validation styling
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { CSSProperties } from 'react';
|
||||
import { Handle, Position, Connection, HandleType } from 'reactflow';
|
||||
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
|
||||
import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
|
||||
import { InputField, OutputField } from '../types';
|
||||
|
||||
const handleBaseStyles: CSSProperties = {
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
borderWidth: 0,
|
||||
};
|
||||
|
||||
const inputHandleStyles: CSSProperties = {
|
||||
left: '-1.7rem',
|
||||
};
|
||||
|
||||
const outputHandleStyles: CSSProperties = {
|
||||
right: '-0.5rem',
|
||||
};
|
||||
|
||||
type FieldHandleProps = {
|
||||
nodeId: string;
|
||||
field: InputField | OutputField;
|
||||
isValidConnection: (connection: Connection) => boolean;
|
||||
handleType: HandleType;
|
||||
styles?: CSSProperties;
|
||||
};
|
||||
|
||||
export const FieldHandle = (props: FieldHandleProps) => {
|
||||
const { nodeId, field, isValidConnection, handleType, styles } = props;
|
||||
const { name, title, type, description } = field;
|
||||
|
||||
const connectionEventStyles = useConnectionEventStyles(
|
||||
nodeId,
|
||||
type,
|
||||
handleType
|
||||
);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
key={name}
|
||||
label={`${title} (${type})`}
|
||||
placement={handleType === 'target' ? 'start' : 'end'}
|
||||
hasArrow
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<Handle
|
||||
type={handleType}
|
||||
id={name}
|
||||
isValidConnection={isValidConnection}
|
||||
position={handleType === 'target' ? Position.Left : Position.Right}
|
||||
style={{
|
||||
backgroundColor: `var(--invokeai-colors-${FIELDS[type].color}-500)`,
|
||||
...styles,
|
||||
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
|
||||
...handleBaseStyles,
|
||||
...connectionEventStyles,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
@@ -8,11 +8,14 @@ import {
|
||||
ReactFlow,
|
||||
ConnectionLineType,
|
||||
OnConnectStart,
|
||||
OnConnectEnd,
|
||||
} from 'reactflow';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgesChanged,
|
||||
nodesChanged,
|
||||
} from '../store/nodesSlice';
|
||||
@@ -40,17 +43,26 @@ export const Flow = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnect: OnConnect = useCallback(
|
||||
(changes) => {
|
||||
console.log('connect');
|
||||
dispatch(connectionMade(changes));
|
||||
const onConnectStart: OnConnectStart = useCallback(
|
||||
(event, params) => {
|
||||
dispatch(connectionStarted(params));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnectStart: OnConnectStart = useCallback((changes) => {
|
||||
console.log('connect start');
|
||||
}, []);
|
||||
const onConnect: OnConnect = useCallback(
|
||||
(connection) => {
|
||||
dispatch(connectionMade(connection));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnectEnd: OnConnectEnd = useCallback(
|
||||
(event) => {
|
||||
dispatch(connectionEnded());
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
@@ -59,8 +71,9 @@ export const Flow = () => {
|
||||
edges={edges}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onConnect={onConnect}
|
||||
onConnectStart={onConnectStart}
|
||||
onConnect={onConnect}
|
||||
onConnectEnd={onConnectEnd}
|
||||
connectionLineType={ConnectionLineType.SmoothStep}
|
||||
defaultEdgeOptions={{ type: 'smoothstep' }}
|
||||
>
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { Handle, Position, Connection } from 'reactflow';
|
||||
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
|
||||
import { InputField } from '../types';
|
||||
|
||||
type InputHandleProps = {
|
||||
nodeId: string;
|
||||
field: InputField;
|
||||
isValidConnection: (connection: Connection) => boolean;
|
||||
};
|
||||
|
||||
export const InputHandle = (props: InputHandleProps) => {
|
||||
const { nodeId, field, isValidConnection } = props;
|
||||
const { name, title, type, description } = field;
|
||||
return (
|
||||
<Tooltip
|
||||
key={name}
|
||||
label={`${title} (${type})`}
|
||||
placement="start"
|
||||
hasArrow
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<Handle
|
||||
type="target"
|
||||
id={name}
|
||||
isValidConnection={isValidConnection}
|
||||
position={Position.Left}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
left: '-1.5rem',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
backgroundColor: `var(--invokeai-colors-${FIELDS[type].color}-500)`,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Connection, Edge, NodeProps, useReactFlow } from 'reactflow';
|
||||
import { Connection, NodeProps, useReactFlow } from 'reactflow';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
@@ -12,8 +12,7 @@ import {
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
import { Invocation } from '../types';
|
||||
import { InputFieldComponent } from './InputFieldComponent';
|
||||
import { OutputHandle } from './OutputHandle';
|
||||
import { InputHandle } from './InputHandle';
|
||||
import { FieldHandle } from './FieldHandle';
|
||||
import { map, size } from 'lodash';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
@@ -73,6 +72,8 @@ export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
|
||||
bg: 'base.800',
|
||||
borderRadius: 'md',
|
||||
boxShadow: 'dark-lg',
|
||||
borderWidth: 2,
|
||||
borderColor: selected ? 'base.400' : 'transparent',
|
||||
}}
|
||||
>
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
@@ -113,28 +114,33 @@ export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
|
||||
</HStack>
|
||||
<InputFieldComponent nodeId={id} field={input} />
|
||||
</FormControl>
|
||||
<InputHandle
|
||||
<FieldHandle
|
||||
nodeId={id}
|
||||
field={input}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="target"
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
</Flex>
|
||||
{map(outputs).map((output, i) => {
|
||||
const top = `${(100 / (size(outputs) + 1)) * (i + 1)}%`;
|
||||
return (
|
||||
<OutputHandle
|
||||
key={output.name}
|
||||
nodeId={id}
|
||||
field={output}
|
||||
isValidConnection={isValidConnection}
|
||||
top={top}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
<Flex>
|
||||
{map(outputs).map((output, i) => {
|
||||
const top = `${(100 / (size(outputs) + 1)) * (i + 1)}%`;
|
||||
|
||||
return (
|
||||
<FieldHandle
|
||||
key={i}
|
||||
nodeId={id}
|
||||
field={output}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="source"
|
||||
styles={{ top }}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -58,7 +58,7 @@ const NodeEditor = () => {
|
||||
<Menu>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
aria-label="Options"
|
||||
aria-label="Add Node"
|
||||
icon={<FaPlus />}
|
||||
sx={{ position: 'absolute', top: 2, left: 2 }}
|
||||
/>
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { Handle, Position, Connection } from 'reactflow';
|
||||
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
|
||||
import { OutputField } from '../types';
|
||||
|
||||
type OutputHandleProps = {
|
||||
nodeId: string;
|
||||
field: OutputField;
|
||||
isValidConnection: (connection: Connection) => boolean;
|
||||
top: string;
|
||||
};
|
||||
|
||||
export const OutputHandle = (props: OutputHandleProps) => {
|
||||
const { nodeId, field, isValidConnection, top } = props;
|
||||
const { name, title, type, description } = field;
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
label={`${title} (${type})`}
|
||||
placement="end"
|
||||
hasArrow
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<Handle
|
||||
type="source"
|
||||
id={name}
|
||||
isValidConnection={isValidConnection}
|
||||
position={Position.Right}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top,
|
||||
right: '-0.5rem',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
backgroundColor: `var(--invokeai-colors-${FIELDS[type].color}-500)`,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,67 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { CSSProperties, useMemo } from 'react';
|
||||
import { HandleType, useReactFlow } from 'reactflow';
|
||||
import { FieldType, Invocation } from '../types';
|
||||
|
||||
const invalidTargetStyles: CSSProperties = {
|
||||
opacity: 0.3,
|
||||
};
|
||||
|
||||
const validTargetStyles: CSSProperties = {};
|
||||
|
||||
export const useConnectionEventStyles = (
|
||||
nodeId: string,
|
||||
fieldType: FieldType,
|
||||
handleType: HandleType
|
||||
) => {
|
||||
const flow = useReactFlow();
|
||||
const pendingConnection = useAppSelector(
|
||||
(state: RootState) => state.nodes.pendingConnection
|
||||
);
|
||||
|
||||
return useMemo(() => {
|
||||
if (!pendingConnection) {
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
handleId,
|
||||
handleType: sourceHandleType,
|
||||
nodeId: sourceNodeId,
|
||||
} = pendingConnection;
|
||||
|
||||
// default to connectable if these are not present - unsure why they ever would not be present...
|
||||
if (!handleId || !sourceNodeId || !handleType) {
|
||||
return validTargetStyles;
|
||||
}
|
||||
|
||||
if (
|
||||
// cannot connect a node's input to its own output
|
||||
nodeId === sourceNodeId
|
||||
) {
|
||||
return invalidTargetStyles;
|
||||
}
|
||||
|
||||
if (
|
||||
// cannot connect inputs to inputs or outputs to outputs
|
||||
handleType === sourceHandleType
|
||||
) {
|
||||
return invalidTargetStyles;
|
||||
}
|
||||
|
||||
const node = flow.getNode(sourceNodeId)?.data as Invocation;
|
||||
|
||||
// handle field types must be the same
|
||||
if (
|
||||
fieldType !==
|
||||
(sourceHandleType === 'target'
|
||||
? node.inputs[handleId].type
|
||||
: node.outputs[handleId].type)
|
||||
) {
|
||||
return invalidTargetStyles;
|
||||
}
|
||||
|
||||
return validTargetStyles;
|
||||
}, [pendingConnection, nodeId, flow, fieldType, handleType]);
|
||||
};
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
EdgeChange,
|
||||
Node,
|
||||
NodeChange,
|
||||
NodeTypes,
|
||||
OnConnectStartParams,
|
||||
} from 'reactflow';
|
||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||
import { Invocation } from '../types';
|
||||
@@ -20,6 +20,7 @@ export type NodesState = {
|
||||
edges: Edge[];
|
||||
schema: OpenAPIV3.Document | null;
|
||||
invocations: Record<string, Invocation>;
|
||||
pendingConnection: OnConnectStartParams | null;
|
||||
};
|
||||
|
||||
export const initialNodesState: NodesState = {
|
||||
@@ -27,6 +28,7 @@ export const initialNodesState: NodesState = {
|
||||
edges: [],
|
||||
schema: null,
|
||||
invocations: {},
|
||||
pendingConnection: null,
|
||||
};
|
||||
|
||||
const nodesSlice = createSlice({
|
||||
@@ -54,9 +56,15 @@ const nodesSlice = createSlice({
|
||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||
},
|
||||
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
||||
state.pendingConnection = action.payload;
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
state.edges = addEdge(action.payload, state.edges);
|
||||
},
|
||||
connectionEnded: (state) => {
|
||||
state.pendingConnection = null;
|
||||
},
|
||||
fieldValueChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@@ -88,6 +96,8 @@ export const {
|
||||
nodeAdded,
|
||||
fieldValueChanged,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
connectionEnded,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
||||
|
||||
Reference in New Issue
Block a user