feat(ui): it blends

This commit is contained in:
psychedelicious
2023-04-09 22:25:14 +10:00
parent ebc76a4785
commit 451fe7abcd
9 changed files with 226 additions and 78 deletions

View File

@@ -0,0 +1,47 @@
import { v4 as uuidv4 } from 'uuid';
import 'reactflow/dist/style.css';
import { useCallback } from 'react';
import {
Tooltip,
Menu,
MenuButton,
MenuList,
MenuItem,
IconButton,
} from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { map } from 'lodash';
import { RootState } from 'app/store';
export const AddNodeMenu = () => {
const dispatch = useAppDispatch();
const invocations = useAppSelector(
(state: RootState) => state.nodes.invocations
);
const addNode = useCallback(
(nodeType: string) => {
dispatch(nodeAdded({ id: uuidv4(), invocation: invocations[nodeType] }));
},
[dispatch, invocations]
);
return (
<Menu>
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
<MenuList>
{map(invocations, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
);
};

View File

@@ -0,0 +1,18 @@
import 'reactflow/dist/style.css';
import { Tooltip, Badge, HStack } from '@chakra-ui/react';
import { map } from 'lodash';
import { FIELDS } from '../constants';
export const FieldTypeLegend = () => {
return (
<HStack>
{map(FIELDS, ({ title, description, color }, key) => (
<Tooltip key={key} label={description}>
<Badge colorScheme={color} sx={{ userSelect: 'none' }}>
{title}
</Badge>
</Tooltip>
))}
</HStack>
);
};

View File

@@ -9,6 +9,7 @@ import {
ConnectionLineType,
OnConnectStart,
OnConnectEnd,
Panel,
} from 'reactflow';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
@@ -21,6 +22,10 @@ import {
} from '../store/nodesSlice';
import { useCallback } from 'react';
import { InvocationComponent } from './InvocationComponent';
import { AddNodeMenu } from './AddNodeMenu';
import { FieldTypeLegend } from './FieldTypeLegend';
import { Button } from '@chakra-ui/react';
import { nodesGraphBuilt } from 'services/thunks/session';
const nodeTypes = { invocation: InvocationComponent };
@@ -64,6 +69,10 @@ export const Flow = () => {
[dispatch]
);
const handleInvoke = useCallback(() => {
dispatch(nodesGraphBuilt());
}, [dispatch]);
return (
<ReactFlow
nodeTypes={nodeTypes}
@@ -80,6 +89,15 @@ export const Flow = () => {
style: { strokeWidth: 2 },
}}
>
<Panel position="top-left">
<AddNodeMenu />
</Panel>
<Panel position="top-center">
<Button onClick={handleInvoke}>Will it blend?</Button>
</Panel>
<Panel position="top-right">
<FieldTypeLegend />
</Panel>
<Background />
<Controls />
<MiniMap nodeStrokeWidth={3} zoomable pannable />

View File

@@ -1,39 +1,14 @@
import { v4 as uuidv4 } from 'uuid';
import 'reactflow/dist/style.css';
import { useCallback } from 'react';
import {
Box,
Tooltip,
Badge,
HStack,
Menu,
MenuButton,
MenuList,
MenuItem,
IconButton,
} from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { Box } from '@chakra-ui/react';
import { Flow } from './Flow';
import { map } from 'lodash';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import { FIELDS } from '../constants';
import { buildNodesGraph } from '../util/buildNodesGraph';
const NodeEditor = () => {
const dispatch = useAppDispatch();
const state = useAppSelector((state: RootState) => state);
const invocations = useAppSelector(
(state: RootState) => state.nodes.invocations
);
const addNode = useCallback(
(nodeType: string) => {
dispatch(nodeAdded({ id: uuidv4(), invocation: invocations[nodeType] }));
},
[dispatch, invocations]
);
const graph = buildNodesGraph(state);
return (
<Box
@@ -46,32 +21,20 @@ const NodeEditor = () => {
}}
>
<Flow />
<HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
{map(FIELDS, ({ title, description, color }, key) => (
<Tooltip key={key} label={description}>
<Badge colorScheme={color} sx={{ userSelect: 'none' }}>
{title}
</Badge>
</Tooltip>
))}
</HStack>
<Menu>
<MenuButton
as={IconButton}
aria-label="Add Node"
icon={<FaPlus />}
sx={{ position: 'absolute', top: 2, left: 2 }}
/>
<MenuList>
{map(invocations, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
<Box
as="pre"
fontFamily="monospace"
position="absolute"
top={2}
left={2}
width="full"
height="full"
userSelect="none"
pointerEvents="none"
opacity={0.7}
>
{JSON.stringify(graph, undefined, 2)}
</Box>
</Box>
);
};

View File

@@ -1,4 +1,4 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf, PayloadAction } from '@reduxjs/toolkit';
import { OpenAPIV3 } from 'openapi-types';
import {
addEdge,
@@ -11,8 +11,15 @@ import {
NodeChange,
OnConnectStartParams,
} from 'reactflow';
import { Graph } from 'services/api';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import {
isFulfilledAnyGraphBuilt,
linearGraphBuilt,
nodesGraphBuilt,
} from 'services/thunks/session';
import { Invocation } from '../types';
import { buildNodesGraph } from '../util/buildNodesGraph';
import { parseSchema } from '../util/parseSchema';
export type NodesState = {
@@ -21,6 +28,7 @@ export type NodesState = {
schema: OpenAPIV3.Document | null;
invocations: Record<string, Invocation>;
pendingConnection: OnConnectStartParams | null;
lastGraph: Graph | null;
};
export const initialNodesState: NodesState = {
@@ -29,6 +37,7 @@ export const initialNodesState: NodesState = {
schema: null,
invocations: {},
pendingConnection: null,
lastGraph: null,
};
const nodesSlice = createSlice({
@@ -87,6 +96,10 @@ const nodesSlice = createSlice({
state.schema = action.payload;
state.invocations = parseSchema(action.payload);
});
builder.addMatcher(isFulfilledAnyGraphBuilt, (state, action) => {
state.lastGraph = action.payload;
});
},
});

View File

@@ -0,0 +1,66 @@
import { Edge, Node } from 'reactflow';
import { Graph } from 'services/api';
import { Invocation } from '../types';
import { v4 as uuidv4 } from 'uuid';
import { reduce } from 'lodash';
import { RootState } from 'app/store';
export const buildNodesGraph = (state: RootState): Graph => {
const { nodes, edges } = state.nodes;
const parsedNodes = nodes.reduce<NonNullable<Graph['nodes']>>(
(nodesAccumulator, node, nodeIndex) => {
const { id, data } = node;
const { type, inputs } = data;
const transformedInputs = reduce(
inputs,
(inputsAccumulator, input, name) => {
inputsAccumulator[name] = input.value;
return inputsAccumulator;
},
{} as Record<string, any>
);
const graphNode = {
type,
id,
...transformedInputs,
};
nodesAccumulator[id] = graphNode;
return nodesAccumulator;
},
{}
);
const parsedEdges = edges.reduce<NonNullable<Graph['edges']>>(
(edgesAccumulator, edge, edgeIndex) => {
const { source, target, sourceHandle, targetHandle } = edge;
edgesAccumulator.push({
source: {
node_id: source,
field: sourceHandle as string,
},
destination: {
node_id: target,
field: targetHandle as string,
},
});
return edgesAccumulator;
},
[]
);
const graph = {
id: uuidv4(),
nodes: parsedNodes,
edges: parsedEdges,
};
return graph as Graph;
};

View File

@@ -11,7 +11,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa';
import { sessionCreated } from 'services/thunks/session';
import { linearGraphBuilt, sessionCreated } from 'services/thunks/session';
interface InvokeButton
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
@@ -26,7 +26,7 @@ export default function InvokeButton(props: InvokeButton) {
const handleClickGenerate = () => {
// dispatch(generateImage(activeTabName));
dispatch(sessionCreated());
dispatch(linearGraphBuilt());
};
const { t } = useTranslation();

View File

@@ -9,7 +9,7 @@ export type ProgressImage = {
height: number;
};
export type AnyInvocation = NonNullable<
export type AnyInvocationType = NonNullable<
NonNullable<Graph['nodes']>[string]['type']
>;
@@ -22,7 +22,7 @@ export type AnyResult = GraphExecutionState['results'][string];
*/
export type GeneratorProgressEvent = {
graph_execution_state_id: string;
invocation: AnyInvocation;
invocation: AnyInvocationType;
source_id: string;
progress_image?: ProgressImage;
step: number;
@@ -49,7 +49,7 @@ export type InvocationCompleteEvent = {
*/
export type InvocationErrorEvent = {
graph_execution_state_id: string;
invocation: AnyInvocation;
invocation: AnyInvocationType;
source_id: string;
error: string;
};
@@ -61,7 +61,7 @@ export type InvocationErrorEvent = {
*/
export type InvocationStartedEvent = {
graph_execution_state_id: string;
invocation: AnyInvocation;
invocation: AnyInvocationType;
source_id: string;
};

View File

@@ -1,29 +1,52 @@
import { createAppAsyncThunk } from 'app/storeUtils';
import { SessionsService } from 'services/api';
import { Graph, SessionsService } from 'services/api';
import { buildGraph } from 'common/util/buildGraph';
import { isFulfilled } from '@reduxjs/toolkit';
import { isAnyOf, isFulfilled } from '@reduxjs/toolkit';
import { subscribedNodeIdsSet } from 'features/system/store/systemSlice';
import { buildNodesGraph } from 'features/nodes/util/buildNodesGraph';
import { size } from 'lodash';
// type SessionCreatedArg = {
// graph: Parameters<
// (typeof SessionsService)['createSession']
// >[0]['requestBody'];
// nodeIdsToSubscribe?: string[];
// };
export const linearGraphBuilt = createAppAsyncThunk(
'api/linearGraphBuilt',
async (_, { dispatch, getState }) => {
const graph = buildGraph(getState()).graph;
dispatch(sessionCreated({ graph }));
return graph;
}
);
export const nodesGraphBuilt = createAppAsyncThunk(
'api/nodesGraphBuilt',
async (_, { dispatch, getState }) => {
const graph = buildNodesGraph(getState());
dispatch(sessionCreated({ graph }));
return graph;
}
);
export const isFulfilledAnyGraphBuilt = isAnyOf(
linearGraphBuilt.fulfilled,
nodesGraphBuilt.fulfilled
);
type SessionCreatedArg = {
graph: Parameters<
(typeof SessionsService)['createSession']
>[0]['requestBody'];
};
/**
* `SessionsService.createSession()` thunk
*/
export const sessionCreated = createAppAsyncThunk(
'api/sessionCreated',
async (_arg, { dispatch, getState }) => {
const state = getState();
const { graph, nodeIdsToSubscribe } = buildGraph(state);
dispatch(subscribedNodeIdsSet(nodeIdsToSubscribe));
async (arg: SessionCreatedArg, { dispatch, getState }) => {
const response = await SessionsService.createSession({
requestBody: graph,
requestBody: arg.graph,
});
return response;