mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): validation connections w/ graphlib
This commit is contained in:
@@ -44,6 +44,7 @@
|
||||
"@chakra-ui/react": "^2.5.1",
|
||||
"@chakra-ui/styled-system": "^2.6.1",
|
||||
"@chakra-ui/theme-tools": "^2.0.16",
|
||||
"@dagrejs/graphlib": "^2.1.12",
|
||||
"@emotion/react": "^11.10.6",
|
||||
"@emotion/styled": "^11.10.6",
|
||||
"@reduxjs/toolkit": "^1.9.3",
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Tooltip } from '@chakra-ui/react';
|
||||
import { CSSProperties } from 'react';
|
||||
import { Handle, Position, Connection, HandleType } from 'reactflow';
|
||||
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
|
||||
import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
|
||||
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
|
||||
import { InputField, OutputField } from '../types';
|
||||
|
||||
const handleBaseStyles: CSSProperties = {
|
||||
@@ -32,11 +32,13 @@ export const FieldHandle = (props: FieldHandleProps) => {
|
||||
const { nodeId, field, isValidConnection, handleType, styles } = props;
|
||||
const { name, title, type, description } = field;
|
||||
|
||||
const connectionEventStyles = useConnectionEventStyles(
|
||||
nodeId,
|
||||
type,
|
||||
handleType
|
||||
);
|
||||
// this needs to iterate over every candicate target node, calculating graph cycles
|
||||
// WIP
|
||||
// const connectionEventStyles = useConnectionEventStyles(
|
||||
// nodeId,
|
||||
// type,
|
||||
// handleType
|
||||
// );
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
@@ -56,7 +58,7 @@ export const FieldHandle = (props: FieldHandleProps) => {
|
||||
...styles,
|
||||
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
|
||||
...handleBaseStyles,
|
||||
...connectionEventStyles,
|
||||
// ...connectionEventStyles,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
|
||||
@@ -83,9 +83,7 @@ export const Flow = () => {
|
||||
onConnectStart={onConnectStart}
|
||||
onConnect={onConnect}
|
||||
onConnectEnd={onConnectEnd}
|
||||
connectionLineType={ConnectionLineType.SmoothStep}
|
||||
defaultEdgeOptions={{
|
||||
type: 'smoothstep',
|
||||
style: { strokeWidth: 2 },
|
||||
}}
|
||||
>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Connection, NodeProps, useReactFlow } from 'reactflow';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
@@ -15,56 +15,14 @@ import { Invocation } from '../types';
|
||||
import { InputFieldComponent } from './InputFieldComponent';
|
||||
import { FieldHandle } from './FieldHandle';
|
||||
import { map, size } from 'lodash';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useIsValidConnection } from '../hooks/useIsValidConnection';
|
||||
|
||||
export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
|
||||
const { id, data, selected } = props;
|
||||
const { type, title, description, inputs, outputs } = data;
|
||||
|
||||
const flow = useReactFlow();
|
||||
|
||||
// Check if an in-progress connection is valid
|
||||
const isValidConnection = useCallback(
|
||||
(connection: Connection): boolean => {
|
||||
const edges = flow.getEdges();
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return (
|
||||
edge.target === connection.target &&
|
||||
edge.targetHandle === connection.targetHandle
|
||||
);
|
||||
})
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Find the source and target nodes...
|
||||
if (connection.source && connection.target) {
|
||||
const sourceNode = flow.getNode(connection.source);
|
||||
const targetNode = flow.getNode(connection.target);
|
||||
|
||||
// Conditional guards against undefined nodes/handles
|
||||
if (
|
||||
sourceNode &&
|
||||
targetNode &&
|
||||
connection.sourceHandle &&
|
||||
connection.targetHandle
|
||||
) {
|
||||
// connection types must be the same for a connection
|
||||
return (
|
||||
sourceNode.data.outputs[connection.sourceHandle].type ===
|
||||
targetNode.data.inputs[connection.targetHandle].type
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Default to invalid
|
||||
return false;
|
||||
},
|
||||
[flow]
|
||||
);
|
||||
const isValidConnection = useIsValidConnection();
|
||||
|
||||
return (
|
||||
<Box
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
import 'reactflow/dist/style.css';
|
||||
import { Box, HStack } from '@chakra-ui/react';
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { ReactFlowProvider } from 'reactflow';
|
||||
|
||||
import { Flow } from './Flow';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import { buildNodesGraph } from '../util/buildNodesGraph';
|
||||
import { buildAdjacencyList } from '../util/isCyclic';
|
||||
|
||||
const NodeEditor = () => {
|
||||
const state = useAppSelector((state: RootState) => state);
|
||||
|
||||
const graph = buildNodesGraph(state);
|
||||
|
||||
const adjacencyList = buildAdjacencyList(state.nodes.edges);
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
@@ -23,7 +22,9 @@ const NodeEditor = () => {
|
||||
bg: 'base.850',
|
||||
}}
|
||||
>
|
||||
<Flow />
|
||||
<ReactFlowProvider>
|
||||
<Flow />
|
||||
</ReactFlowProvider>
|
||||
<Box
|
||||
as="pre"
|
||||
fontFamily="monospace"
|
||||
@@ -36,10 +37,7 @@ const NodeEditor = () => {
|
||||
pointerEvents="none"
|
||||
opacity={0.7}
|
||||
>
|
||||
<HStack alignItems={'flex-start'} justifyContent="space-between">
|
||||
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
|
||||
<Box w="50%">{JSON.stringify(adjacencyList, null, 2)}</Box>
|
||||
</HStack>
|
||||
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
import { useCallback } from 'react';
|
||||
import { Connection, useReactFlow } from 'reactflow';
|
||||
import graphlib from '@dagrejs/graphlib';
|
||||
|
||||
export const useIsValidConnection = () => {
|
||||
const flow = useReactFlow();
|
||||
|
||||
// Check if an in-progress connection is valid
|
||||
const isValidConnection = useCallback(
|
||||
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
|
||||
const edges = flow.getEdges();
|
||||
const nodes = flow.getNodes();
|
||||
|
||||
// Connection must have valid targets
|
||||
if (!(source && sourceHandle && target && targetHandle)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
})
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Find the source and target nodes
|
||||
const sourceNode = flow.getNode(source);
|
||||
const targetNode = flow.getNode(target);
|
||||
|
||||
// Conditional guards against undefined nodes/handles
|
||||
if (!(sourceNode && targetNode)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Connection types must be the same for a connection
|
||||
if (
|
||||
sourceNode.data.outputs[sourceHandle].type !==
|
||||
targetNode.data.inputs[targetHandle].type
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Graphs much be acyclic (no loops!)
|
||||
|
||||
// build a graphlib graph
|
||||
const g = new graphlib.Graph();
|
||||
|
||||
nodes.forEach((n) => {
|
||||
g.setNode(n.id);
|
||||
});
|
||||
|
||||
edges.forEach((e) => {
|
||||
g.setEdge(e.source, e.target);
|
||||
});
|
||||
|
||||
// Add the candidate edge to the graph
|
||||
g.setEdge(source, target);
|
||||
|
||||
return graphlib.alg.isAcyclic(g);
|
||||
},
|
||||
[flow]
|
||||
);
|
||||
|
||||
return isValidConnection;
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,58 +0,0 @@
|
||||
import collections
|
||||
|
||||
|
||||
class Graph(object):
|
||||
def __init__(self, edges):
|
||||
self.edges = edges
|
||||
self.adj = Graph._build_adjacency_list(edges)
|
||||
|
||||
@staticmethod
|
||||
def _build_adjacency_list(edges):
|
||||
adj = collections.defaultdict(list)
|
||||
for edge in edges:
|
||||
adj[edge[0]].append(edge[1])
|
||||
adj[edge[1]] # side effect only
|
||||
return adj
|
||||
|
||||
|
||||
def dfs(G):
|
||||
discovered = set()
|
||||
finished = set()
|
||||
|
||||
for u in G.adj:
|
||||
if u not in discovered and u not in finished:
|
||||
discovered, finished = dfs_visit(G, u, discovered, finished)
|
||||
|
||||
|
||||
def dfs_visit(G, u, discovered, finished):
|
||||
discovered.add(u)
|
||||
|
||||
for v in G.adj[u]:
|
||||
# Detect cycles
|
||||
if v in discovered:
|
||||
print(f"Cycle detected: found a back edge from {u} to {v}.")
|
||||
break
|
||||
|
||||
# Recurse into DFS tree
|
||||
if v not in finished:
|
||||
dfs_visit(G, v, discovered, finished)
|
||||
|
||||
discovered.remove(u)
|
||||
finished.add(u)
|
||||
|
||||
return discovered, finished
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
G = Graph([
|
||||
('u', 'v'),
|
||||
('u', 'x'),
|
||||
('v', 'y'),
|
||||
('w', 'y'),
|
||||
('w', 'z'),
|
||||
('x', 'v'),
|
||||
('y', 'x'),
|
||||
('z', 'z')])
|
||||
print(G.adj)
|
||||
|
||||
dfs(G)
|
||||
@@ -1,21 +0,0 @@
|
||||
import { Edge } from 'reactflow';
|
||||
|
||||
export type AdjacencyList = { [nodeId: string]: string[] };
|
||||
|
||||
export const buildAdjacencyList = (edges: Edge[]): AdjacencyList => {
|
||||
const adjacencyList: AdjacencyList = {};
|
||||
|
||||
edges.forEach((edge) => {
|
||||
if (!adjacencyList[edge.source]) {
|
||||
adjacencyList[edge.source] = [];
|
||||
}
|
||||
|
||||
if (!adjacencyList[edge.target]) {
|
||||
adjacencyList[edge.target] = [];
|
||||
}
|
||||
|
||||
adjacencyList[edge.source].push(edge.target);
|
||||
});
|
||||
|
||||
return adjacencyList;
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
@@ -907,6 +907,11 @@
|
||||
dependencies:
|
||||
"@jridgewell/trace-mapping" "0.3.9"
|
||||
|
||||
"@dagrejs/graphlib@^2.1.12":
|
||||
version "2.1.12"
|
||||
resolved "https://registry.yarnpkg.com/@dagrejs/graphlib/-/graphlib-2.1.12.tgz#97d29eae006e4efcb68863505464e0e3f28fa5c7"
|
||||
integrity sha512-yHk2G7ZNzDEHhQTlYtbtEy5PqlIoioCxZUKcrlBgubMvrLmewXqSV3v4rhc8RAt5s8lr8PcWbiovEPuORxe2KA==
|
||||
|
||||
"@emotion/babel-plugin@^11.10.6":
|
||||
version "11.10.6"
|
||||
resolved "https://registry.yarnpkg.com/@emotion/babel-plugin/-/babel-plugin-11.10.6.tgz#a68ee4b019d661d6f37dec4b8903255766925ead"
|
||||
|
||||
Reference in New Issue
Block a user