mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): it blends
This commit is contained in:
@@ -0,0 +1,47 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import 'reactflow/dist/style.css';
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
Tooltip,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { nodeAdded } from '../store/nodesSlice';
|
||||
import { map } from 'lodash';
|
||||
import { RootState } from 'app/store';
|
||||
|
||||
export const AddNodeMenu = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const invocations = useAppSelector(
|
||||
(state: RootState) => state.nodes.invocations
|
||||
);
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string) => {
|
||||
dispatch(nodeAdded({ id: uuidv4(), invocation: invocations[nodeType] }));
|
||||
},
|
||||
[dispatch, invocations]
|
||||
);
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
|
||||
<MenuList>
|
||||
{map(invocations, ({ title, description, type }, key) => {
|
||||
return (
|
||||
<Tooltip key={key} label={description} placement="end" hasArrow>
|
||||
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,18 @@
|
||||
import 'reactflow/dist/style.css';
|
||||
import { Tooltip, Badge, HStack } from '@chakra-ui/react';
|
||||
import { map } from 'lodash';
|
||||
import { FIELDS } from '../constants';
|
||||
|
||||
export const FieldTypeLegend = () => {
|
||||
return (
|
||||
<HStack>
|
||||
{map(FIELDS, ({ title, description, color }, key) => (
|
||||
<Tooltip key={key} label={description}>
|
||||
<Badge colorScheme={color} sx={{ userSelect: 'none' }}>
|
||||
{title}
|
||||
</Badge>
|
||||
</Tooltip>
|
||||
))}
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
ConnectionLineType,
|
||||
OnConnectStart,
|
||||
OnConnectEnd,
|
||||
Panel,
|
||||
} from 'reactflow';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
@@ -21,6 +22,10 @@ import {
|
||||
} from '../store/nodesSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { InvocationComponent } from './InvocationComponent';
|
||||
import { AddNodeMenu } from './AddNodeMenu';
|
||||
import { FieldTypeLegend } from './FieldTypeLegend';
|
||||
import { Button } from '@chakra-ui/react';
|
||||
import { nodesGraphBuilt } from 'services/thunks/session';
|
||||
|
||||
const nodeTypes = { invocation: InvocationComponent };
|
||||
|
||||
@@ -64,6 +69,10 @@ export const Flow = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleInvoke = useCallback(() => {
|
||||
dispatch(nodesGraphBuilt());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
nodeTypes={nodeTypes}
|
||||
@@ -80,6 +89,15 @@ export const Flow = () => {
|
||||
style: { strokeWidth: 2 },
|
||||
}}
|
||||
>
|
||||
<Panel position="top-left">
|
||||
<AddNodeMenu />
|
||||
</Panel>
|
||||
<Panel position="top-center">
|
||||
<Button onClick={handleInvoke}>Will it blend?</Button>
|
||||
</Panel>
|
||||
<Panel position="top-right">
|
||||
<FieldTypeLegend />
|
||||
</Panel>
|
||||
<Background />
|
||||
<Controls />
|
||||
<MiniMap nodeStrokeWidth={3} zoomable pannable />
|
||||
|
||||
@@ -1,39 +1,14 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import 'reactflow/dist/style.css';
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
Box,
|
||||
Tooltip,
|
||||
Badge,
|
||||
HStack,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { nodeAdded } from '../store/nodesSlice';
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { Flow } from './Flow';
|
||||
import { map } from 'lodash';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import { FIELDS } from '../constants';
|
||||
import { buildNodesGraph } from '../util/buildNodesGraph';
|
||||
|
||||
const NodeEditor = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const state = useAppSelector((state: RootState) => state);
|
||||
|
||||
const invocations = useAppSelector(
|
||||
(state: RootState) => state.nodes.invocations
|
||||
);
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string) => {
|
||||
dispatch(nodeAdded({ id: uuidv4(), invocation: invocations[nodeType] }));
|
||||
},
|
||||
[dispatch, invocations]
|
||||
);
|
||||
const graph = buildNodesGraph(state);
|
||||
|
||||
return (
|
||||
<Box
|
||||
@@ -46,32 +21,20 @@ const NodeEditor = () => {
|
||||
}}
|
||||
>
|
||||
<Flow />
|
||||
<HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
|
||||
{map(FIELDS, ({ title, description, color }, key) => (
|
||||
<Tooltip key={key} label={description}>
|
||||
<Badge colorScheme={color} sx={{ userSelect: 'none' }}>
|
||||
{title}
|
||||
</Badge>
|
||||
</Tooltip>
|
||||
))}
|
||||
</HStack>
|
||||
<Menu>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
aria-label="Add Node"
|
||||
icon={<FaPlus />}
|
||||
sx={{ position: 'absolute', top: 2, left: 2 }}
|
||||
/>
|
||||
<MenuList>
|
||||
{map(invocations, ({ title, description, type }, key) => {
|
||||
return (
|
||||
<Tooltip key={key} label={description} placement="end" hasArrow>
|
||||
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
<Box
|
||||
as="pre"
|
||||
fontFamily="monospace"
|
||||
position="absolute"
|
||||
top={2}
|
||||
left={2}
|
||||
width="full"
|
||||
height="full"
|
||||
userSelect="none"
|
||||
pointerEvents="none"
|
||||
opacity={0.7}
|
||||
>
|
||||
{JSON.stringify(graph, undefined, 2)}
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import {
|
||||
addEdge,
|
||||
@@ -11,8 +11,15 @@ import {
|
||||
NodeChange,
|
||||
OnConnectStartParams,
|
||||
} from 'reactflow';
|
||||
import { Graph } from 'services/api';
|
||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||
import {
|
||||
isFulfilledAnyGraphBuilt,
|
||||
linearGraphBuilt,
|
||||
nodesGraphBuilt,
|
||||
} from 'services/thunks/session';
|
||||
import { Invocation } from '../types';
|
||||
import { buildNodesGraph } from '../util/buildNodesGraph';
|
||||
import { parseSchema } from '../util/parseSchema';
|
||||
|
||||
export type NodesState = {
|
||||
@@ -21,6 +28,7 @@ export type NodesState = {
|
||||
schema: OpenAPIV3.Document | null;
|
||||
invocations: Record<string, Invocation>;
|
||||
pendingConnection: OnConnectStartParams | null;
|
||||
lastGraph: Graph | null;
|
||||
};
|
||||
|
||||
export const initialNodesState: NodesState = {
|
||||
@@ -29,6 +37,7 @@ export const initialNodesState: NodesState = {
|
||||
schema: null,
|
||||
invocations: {},
|
||||
pendingConnection: null,
|
||||
lastGraph: null,
|
||||
};
|
||||
|
||||
const nodesSlice = createSlice({
|
||||
@@ -87,6 +96,10 @@ const nodesSlice = createSlice({
|
||||
state.schema = action.payload;
|
||||
state.invocations = parseSchema(action.payload);
|
||||
});
|
||||
|
||||
builder.addMatcher(isFulfilledAnyGraphBuilt, (state, action) => {
|
||||
state.lastGraph = action.payload;
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import { Edge, Node } from 'reactflow';
|
||||
import { Graph } from 'services/api';
|
||||
import { Invocation } from '../types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { reduce } from 'lodash';
|
||||
import { RootState } from 'app/store';
|
||||
|
||||
export const buildNodesGraph = (state: RootState): Graph => {
|
||||
const { nodes, edges } = state.nodes;
|
||||
|
||||
const parsedNodes = nodes.reduce<NonNullable<Graph['nodes']>>(
|
||||
(nodesAccumulator, node, nodeIndex) => {
|
||||
const { id, data } = node;
|
||||
const { type, inputs } = data;
|
||||
|
||||
const transformedInputs = reduce(
|
||||
inputs,
|
||||
(inputsAccumulator, input, name) => {
|
||||
inputsAccumulator[name] = input.value;
|
||||
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{} as Record<string, any>
|
||||
);
|
||||
|
||||
const graphNode = {
|
||||
type,
|
||||
id,
|
||||
...transformedInputs,
|
||||
};
|
||||
|
||||
nodesAccumulator[id] = graphNode;
|
||||
|
||||
return nodesAccumulator;
|
||||
},
|
||||
{}
|
||||
);
|
||||
|
||||
const parsedEdges = edges.reduce<NonNullable<Graph['edges']>>(
|
||||
(edgesAccumulator, edge, edgeIndex) => {
|
||||
const { source, target, sourceHandle, targetHandle } = edge;
|
||||
|
||||
edgesAccumulator.push({
|
||||
source: {
|
||||
node_id: source,
|
||||
field: sourceHandle as string,
|
||||
},
|
||||
destination: {
|
||||
node_id: target,
|
||||
field: targetHandle as string,
|
||||
},
|
||||
});
|
||||
|
||||
return edgesAccumulator;
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const graph = {
|
||||
id: uuidv4(),
|
||||
nodes: parsedNodes,
|
||||
edges: parsedEdges,
|
||||
};
|
||||
|
||||
return graph as Graph;
|
||||
};
|
||||
@@ -11,7 +11,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaPlay } from 'react-icons/fa';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { linearGraphBuilt, sessionCreated } from 'services/thunks/session';
|
||||
|
||||
interface InvokeButton
|
||||
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
|
||||
@@ -26,7 +26,7 @@ export default function InvokeButton(props: InvokeButton) {
|
||||
|
||||
const handleClickGenerate = () => {
|
||||
// dispatch(generateImage(activeTabName));
|
||||
dispatch(sessionCreated());
|
||||
dispatch(linearGraphBuilt());
|
||||
};
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -9,7 +9,7 @@ export type ProgressImage = {
|
||||
height: number;
|
||||
};
|
||||
|
||||
export type AnyInvocation = NonNullable<
|
||||
export type AnyInvocationType = NonNullable<
|
||||
NonNullable<Graph['nodes']>[string]['type']
|
||||
>;
|
||||
|
||||
@@ -22,7 +22,7 @@ export type AnyResult = GraphExecutionState['results'][string];
|
||||
*/
|
||||
export type GeneratorProgressEvent = {
|
||||
graph_execution_state_id: string;
|
||||
invocation: AnyInvocation;
|
||||
invocation: AnyInvocationType;
|
||||
source_id: string;
|
||||
progress_image?: ProgressImage;
|
||||
step: number;
|
||||
@@ -49,7 +49,7 @@ export type InvocationCompleteEvent = {
|
||||
*/
|
||||
export type InvocationErrorEvent = {
|
||||
graph_execution_state_id: string;
|
||||
invocation: AnyInvocation;
|
||||
invocation: AnyInvocationType;
|
||||
source_id: string;
|
||||
error: string;
|
||||
};
|
||||
@@ -61,7 +61,7 @@ export type InvocationErrorEvent = {
|
||||
*/
|
||||
export type InvocationStartedEvent = {
|
||||
graph_execution_state_id: string;
|
||||
invocation: AnyInvocation;
|
||||
invocation: AnyInvocationType;
|
||||
source_id: string;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,29 +1,52 @@
|
||||
import { createAppAsyncThunk } from 'app/storeUtils';
|
||||
import { SessionsService } from 'services/api';
|
||||
import { Graph, SessionsService } from 'services/api';
|
||||
import { buildGraph } from 'common/util/buildGraph';
|
||||
import { isFulfilled } from '@reduxjs/toolkit';
|
||||
import { isAnyOf, isFulfilled } from '@reduxjs/toolkit';
|
||||
import { subscribedNodeIdsSet } from 'features/system/store/systemSlice';
|
||||
import { buildNodesGraph } from 'features/nodes/util/buildNodesGraph';
|
||||
import { size } from 'lodash';
|
||||
|
||||
// type SessionCreatedArg = {
|
||||
// graph: Parameters<
|
||||
// (typeof SessionsService)['createSession']
|
||||
// >[0]['requestBody'];
|
||||
// nodeIdsToSubscribe?: string[];
|
||||
// };
|
||||
export const linearGraphBuilt = createAppAsyncThunk(
|
||||
'api/linearGraphBuilt',
|
||||
async (_, { dispatch, getState }) => {
|
||||
const graph = buildGraph(getState()).graph;
|
||||
|
||||
dispatch(sessionCreated({ graph }));
|
||||
|
||||
return graph;
|
||||
}
|
||||
);
|
||||
|
||||
export const nodesGraphBuilt = createAppAsyncThunk(
|
||||
'api/nodesGraphBuilt',
|
||||
async (_, { dispatch, getState }) => {
|
||||
const graph = buildNodesGraph(getState());
|
||||
|
||||
dispatch(sessionCreated({ graph }));
|
||||
|
||||
return graph;
|
||||
}
|
||||
);
|
||||
|
||||
export const isFulfilledAnyGraphBuilt = isAnyOf(
|
||||
linearGraphBuilt.fulfilled,
|
||||
nodesGraphBuilt.fulfilled
|
||||
);
|
||||
|
||||
type SessionCreatedArg = {
|
||||
graph: Parameters<
|
||||
(typeof SessionsService)['createSession']
|
||||
>[0]['requestBody'];
|
||||
};
|
||||
|
||||
/**
|
||||
* `SessionsService.createSession()` thunk
|
||||
*/
|
||||
export const sessionCreated = createAppAsyncThunk(
|
||||
'api/sessionCreated',
|
||||
async (_arg, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
const { graph, nodeIdsToSubscribe } = buildGraph(state);
|
||||
|
||||
dispatch(subscribedNodeIdsSet(nodeIdsToSubscribe));
|
||||
|
||||
async (arg: SessionCreatedArg, { dispatch, getState }) => {
|
||||
const response = await SessionsService.createSession({
|
||||
requestBody: graph,
|
||||
requestBody: arg.graph,
|
||||
});
|
||||
|
||||
return response;
|
||||
|
||||
Reference in New Issue
Block a user