mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): migrate add node popover to cmdk
Put this together as a way to figure out the library before moving on to the full app cmdk. Works great.
This commit is contained in:
@@ -2,6 +2,7 @@ import 'reactflow/dist/style.css';
|
||||
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { AddNodeCmdk } from 'features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk';
|
||||
import TopPanel from 'features/nodes/components/flow/panels/TopPanel/TopPanel';
|
||||
import { LoadWorkflowFromGraphModal } from 'features/workflowLibrary/components/LoadWorkflowFromGraphModal/LoadWorkflowFromGraphModal';
|
||||
import { SaveWorkflowAsDialog } from 'features/workflowLibrary/components/SaveWorkflowAsDialog/SaveWorkflowAsDialog';
|
||||
@@ -10,7 +11,6 @@ import { useTranslation } from 'react-i18next';
|
||||
import { MdDeviceHub } from 'react-icons/md';
|
||||
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
|
||||
|
||||
import AddNodePopover from './flow/AddNodePopover/AddNodePopover';
|
||||
import { Flow } from './flow/Flow';
|
||||
import BottomLeftPanel from './flow/panels/BottomLeftPanel/BottomLeftPanel';
|
||||
import MinimapPanel from './flow/panels/MinimapPanel/MinimapPanel';
|
||||
@@ -31,7 +31,7 @@ const NodeEditor = () => {
|
||||
{data && (
|
||||
<>
|
||||
<Flow />
|
||||
<AddNodePopover />
|
||||
<AddNodeCmdk />
|
||||
<TopPanel />
|
||||
<BottomLeftPanel />
|
||||
<MinimapPanel />
|
||||
|
||||
@@ -0,0 +1,420 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
Icon,
|
||||
Input,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalContent,
|
||||
ModalOverlay,
|
||||
Spacer,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import {
|
||||
$addNodeCmdk,
|
||||
$cursorPos,
|
||||
$edgePendingUpdate,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
edgesChanged,
|
||||
nodesChanged,
|
||||
useAddNodeCmdk,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memoize } from 'lodash-es';
|
||||
import { computed } from 'nanostores';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFlaskBold, PiHammerBold } from 'react-icons/pi';
|
||||
import type { EdgeChange, NodeChange } from 'reactflow';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
const useThrottle = <T,>(value: T, limit: number) => {
|
||||
const [throttledValue, setThrottledValue] = useState(value);
|
||||
const lastRan = useRef(Date.now());
|
||||
|
||||
useEffect(() => {
|
||||
const handler = setTimeout(
|
||||
function () {
|
||||
if (Date.now() - lastRan.current >= limit) {
|
||||
setThrottledValue(value);
|
||||
lastRan.current = Date.now();
|
||||
}
|
||||
},
|
||||
limit - (Date.now() - lastRan.current)
|
||||
);
|
||||
|
||||
return () => {
|
||||
clearTimeout(handler);
|
||||
};
|
||||
}, [value, limit]);
|
||||
|
||||
return throttledValue;
|
||||
};
|
||||
|
||||
const useAddNode = () => {
|
||||
const { t } = useTranslation();
|
||||
const store = useAppStore();
|
||||
const buildInvocation = useBuildNode();
|
||||
const templates = useStore($templates);
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string): void => {
|
||||
const node = buildInvocation(nodeType);
|
||||
if (!node) {
|
||||
const errorMessage = t('nodes.unknownNode', {
|
||||
nodeType: nodeType,
|
||||
});
|
||||
toast({
|
||||
status: 'error',
|
||||
title: errorMessage,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Find a cozy spot for the node
|
||||
const cursorPos = $cursorPos.get();
|
||||
const { nodes, edges } = selectNodesSlice(store.getState());
|
||||
node.position = findUnoccupiedPosition(nodes, cursorPos?.x ?? node.position.x, cursorPos?.y ?? node.position.y);
|
||||
node.selected = true;
|
||||
|
||||
// Deselect all other nodes and edges
|
||||
const nodeChanges: NodeChange[] = [{ type: 'add', item: node }];
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
nodes.forEach(({ id, selected }) => {
|
||||
if (selected) {
|
||||
nodeChanges.push({ type: 'select', id, selected: false });
|
||||
}
|
||||
});
|
||||
edges.forEach(({ id, selected }) => {
|
||||
if (selected) {
|
||||
edgeChanges.push({ type: 'select', id, selected: false });
|
||||
}
|
||||
});
|
||||
|
||||
// Onwards!
|
||||
if (nodeChanges.length > 0) {
|
||||
store.dispatch(nodesChanged(nodeChanges));
|
||||
}
|
||||
if (edgeChanges.length > 0) {
|
||||
store.dispatch(edgesChanged(edgeChanges));
|
||||
}
|
||||
|
||||
// Auto-connect an edge if we just added a node and have a pending connection
|
||||
if (pendingConnection && isInvocationNode(node)) {
|
||||
const edgePendingUpdate = $edgePendingUpdate.get();
|
||||
const { handleType } = pendingConnection;
|
||||
|
||||
const source = handleType === 'source' ? pendingConnection.nodeId : node.id;
|
||||
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
|
||||
const target = handleType === 'target' ? pendingConnection.nodeId : node.id;
|
||||
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
|
||||
|
||||
const { nodes, edges } = selectNodesSlice(store.getState());
|
||||
const connection = getFirstValidConnection(
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
targetHandle,
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
edgePendingUpdate
|
||||
);
|
||||
if (connection) {
|
||||
const newEdge = connectionToEdge(connection);
|
||||
store.dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
|
||||
}
|
||||
}
|
||||
},
|
||||
[buildInvocation, pendingConnection, store, t, templates]
|
||||
);
|
||||
|
||||
return addNode;
|
||||
};
|
||||
|
||||
const cmdkRootSx: SystemStyleObject = {
|
||||
'[cmdk-root]': {
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
},
|
||||
'[cmdk-list]': {
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
},
|
||||
};
|
||||
|
||||
export const AddNodeCmdk = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const addNodeCmdk = useAddNodeCmdk();
|
||||
const addNodeCmdkIsOpen = useStore(addNodeCmdk.$boolean);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const addNode = useAddNode();
|
||||
const throttledSearchTerm = useThrottle(searchTerm, 100);
|
||||
|
||||
useHotkeys(['shift+a', 'space'], addNodeCmdk.setTrue, { preventDefault: true });
|
||||
|
||||
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setSearchTerm(e.target.value);
|
||||
}, []);
|
||||
|
||||
const onSelect = useCallback(
|
||||
(value: string) => {
|
||||
addNode(value);
|
||||
$addNodeCmdk.set(false);
|
||||
setSearchTerm('');
|
||||
},
|
||||
[addNode]
|
||||
);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
addNodeCmdk.setFalse();
|
||||
setSearchTerm('');
|
||||
$pendingConnection.set(null);
|
||||
}, [addNodeCmdk]);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
isOpen={addNodeCmdkIsOpen}
|
||||
onClose={onClose}
|
||||
useInert={false}
|
||||
initialFocusRef={inputRef}
|
||||
size="xl"
|
||||
isCentered
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent h="512" maxH="70%">
|
||||
<ModalBody p={2} h="full" sx={cmdkRootSx}>
|
||||
<CommandRoot loop shouldFilter={false}>
|
||||
<Flex flexDir="column" h="full" gap={2}>
|
||||
<Input ref={inputRef} value={searchTerm} onChange={onChange} placeholder={t('nodes.nodeSearch')} />
|
||||
<Box w="full" h="full">
|
||||
<ScrollableContent>
|
||||
<CommandEmpty>
|
||||
<IAINoContentFallback
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
icon={null}
|
||||
label="No matching items"
|
||||
/>
|
||||
</CommandEmpty>
|
||||
<CommandList>
|
||||
<NodeCommandList searchTerm={throttledSearchTerm} onSelect={onSelect} />
|
||||
</CommandList>
|
||||
</ScrollableContent>
|
||||
</Box>
|
||||
</Flex>
|
||||
</CommandRoot>
|
||||
</ModalBody>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
});
|
||||
|
||||
AddNodeCmdk.displayName = 'AddNodeCmdk';
|
||||
|
||||
const cmdkItemSx: SystemStyleObject = {
|
||||
'&[data-selected="true"]': {
|
||||
bg: 'base.700',
|
||||
},
|
||||
};
|
||||
|
||||
type NodeCommandItemData = {
|
||||
value: string;
|
||||
label: string;
|
||||
description: string;
|
||||
classification: S['Classification'];
|
||||
nodePack: string;
|
||||
};
|
||||
|
||||
const $templatesArray = computed($templates, (templates) => Object.values(templates));
|
||||
|
||||
const createRegex = memoize(
|
||||
(inputValue: string) =>
|
||||
new RegExp(
|
||||
inputValue
|
||||
.trim()
|
||||
.replace(/[-[\]{}()*+!<=:?./\\^$|#,]/g, '')
|
||||
.split(' ')
|
||||
.join('.*'),
|
||||
'gi'
|
||||
)
|
||||
);
|
||||
|
||||
// Filterable items are a subset of Invocation template - we also want to filter for notes or current image node,
|
||||
// so we are using a less specific type instead of `InvocationTemplate`
|
||||
type FilterableItem = {
|
||||
type: string;
|
||||
title: string;
|
||||
description: string;
|
||||
tags: string[];
|
||||
classification: S['Classification'];
|
||||
nodePack: string;
|
||||
};
|
||||
|
||||
const filter = memoize(
|
||||
(item: FilterableItem, searchTerm: string) => {
|
||||
const regex = createRegex(searchTerm);
|
||||
|
||||
if (!searchTerm) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (item.title.includes(searchTerm) || regex.test(item.title)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (item.type.includes(searchTerm) || regex.test(item.type)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (item.description.includes(searchTerm) || regex.test(item.description)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (item.nodePack.includes(searchTerm) || regex.test(item.nodePack)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (item.classification.includes(searchTerm) || regex.test(item.classification)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (const tag of item.tags) {
|
||||
if (tag.includes(searchTerm) || regex.test(tag)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
},
|
||||
(item: FilterableItem, searchTerm: string) => `${item.type}-${searchTerm}`
|
||||
);
|
||||
|
||||
const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; onSelect: (value: string) => void }) => {
|
||||
const { t } = useTranslation();
|
||||
const templatesArray = useStore($templatesArray);
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const currentImageFilterItem = useMemo<FilterableItem>(
|
||||
() => ({
|
||||
type: 'current_image',
|
||||
title: t('nodes.currentImage'),
|
||||
description: t('nodes.currentImageDescription'),
|
||||
tags: ['progress', 'image', 'current'],
|
||||
classification: 'stable',
|
||||
nodePack: 'invokeai',
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
const notesFilterItem = useMemo<FilterableItem>(
|
||||
() => ({
|
||||
type: 'notes',
|
||||
title: t('nodes.notes'),
|
||||
description: t('nodes.notesDescription'),
|
||||
tags: ['notes'],
|
||||
classification: 'stable',
|
||||
nodePack: 'invokeai',
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
|
||||
const items = useMemo<NodeCommandItemData[]>(() => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
const _items: NodeCommandItemData[] = [];
|
||||
|
||||
if (!pendingConnection) {
|
||||
for (const template of templatesArray) {
|
||||
if (filter(template, searchTerm)) {
|
||||
_items.push({
|
||||
label: template.title,
|
||||
value: template.type,
|
||||
description: template.description,
|
||||
classification: template.classification,
|
||||
nodePack: template.nodePack,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for (const item of [currentImageFilterItem, notesFilterItem]) {
|
||||
if (filter(item, searchTerm)) {
|
||||
_items.push({
|
||||
label: item.title,
|
||||
value: item.type,
|
||||
description: item.description,
|
||||
classification: item.classification,
|
||||
nodePack: item.nodePack,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (const template of templatesArray) {
|
||||
if (filter(template, searchTerm)) {
|
||||
const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
|
||||
|
||||
for (const field of Object.values(candidateFields)) {
|
||||
const sourceType =
|
||||
pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
const targetType =
|
||||
pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
|
||||
if (validateConnectionTypes(sourceType, targetType)) {
|
||||
_items.push({
|
||||
label: template.title,
|
||||
value: template.type,
|
||||
description: template.description,
|
||||
classification: template.classification,
|
||||
nodePack: template.nodePack,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return _items;
|
||||
}, [pendingConnection, currentImageFilterItem, searchTerm, notesFilterItem, templatesArray]);
|
||||
|
||||
return (
|
||||
<>
|
||||
{items.map((item) => (
|
||||
<CommandItem key={item.value} value={item.value} onSelect={onSelect} asChild>
|
||||
<Flex role="button" flexDir="column" sx={cmdkItemSx} py={1} px={2} borderRadius="base">
|
||||
<Flex alignItems="center" gap={2}>
|
||||
{item.classification === 'beta' && <Icon boxSize={4} color="invokeYellow.300" as={PiHammerBold} />}
|
||||
{item.classification === 'prototype' && <Icon boxSize={4} color="invokeRed.300" as={PiFlaskBold} />}
|
||||
<Text fontWeight="semibold">{item.label}</Text>
|
||||
<Spacer />
|
||||
<Text variant="subtext" fontWeight="semibold">
|
||||
{item.nodePack}
|
||||
</Text>
|
||||
</Flex>
|
||||
{item.description && <Text color="base.200">{item.description}</Text>}
|
||||
</Flex>
|
||||
</CommandItem>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
NodeCommandList.displayName = 'CommandListItems';
|
||||
@@ -1,267 +0,0 @@
|
||||
import 'reactflow/dist/style.css';
|
||||
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
|
||||
import type { SelectInstance } from 'chakra-react-select';
|
||||
import { INTERACTION_SCOPES } from 'common/hooks/interactionScopes';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import {
|
||||
$cursorPos,
|
||||
$edgePendingUpdate,
|
||||
$isAddNodePopoverOpen,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
closeAddNodePopover,
|
||||
edgesChanged,
|
||||
nodesChanged,
|
||||
openAddNodePopover,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { filter, map, memoize, some } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
|
||||
import type { EdgeChange, NodeChange } from 'reactflow';
|
||||
|
||||
const createRegex = memoize(
|
||||
(inputValue: string) =>
|
||||
new RegExp(
|
||||
inputValue
|
||||
.trim()
|
||||
.replace(/[-[\]{}()*+!<=:?./\\^$|#,]/g, '')
|
||||
.split(' ')
|
||||
.join('.*'),
|
||||
'gi'
|
||||
)
|
||||
);
|
||||
|
||||
const filterOption = memoize((option: FilterOptionOption<ComboboxOption>, inputValue: string) => {
|
||||
if (!inputValue) {
|
||||
return true;
|
||||
}
|
||||
const regex = createRegex(inputValue);
|
||||
return (
|
||||
regex.test(option.label) ||
|
||||
regex.test(option.data.description ?? '') ||
|
||||
(option.data.tags ?? []).some((tag) => regex.test(tag))
|
||||
);
|
||||
});
|
||||
|
||||
const AddNodePopover = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const buildInvocation = useBuildNode();
|
||||
const { t } = useTranslation();
|
||||
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const templates = useStore($templates);
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const isOpen = useStore($isAddNodePopoverOpen);
|
||||
const store = useAppStore();
|
||||
const isWorkflowsActive = useStore(INTERACTION_SCOPES.workflows.$isActive);
|
||||
|
||||
const filteredTemplates = useMemo(() => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
const templatesArray = map(templates);
|
||||
if (!pendingConnection) {
|
||||
return templatesArray;
|
||||
}
|
||||
|
||||
return filter(templates, (template) => {
|
||||
const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
|
||||
return some(candidateFields, (field) => {
|
||||
const sourceType =
|
||||
pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
const targetType =
|
||||
pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
return validateConnectionTypes(sourceType, targetType);
|
||||
});
|
||||
});
|
||||
}, [templates, pendingConnection]);
|
||||
|
||||
const options = useMemo(() => {
|
||||
const _options: ComboboxOption[] = map(filteredTemplates, (template) => {
|
||||
return {
|
||||
label: template.title,
|
||||
value: template.type,
|
||||
description: template.description,
|
||||
tags: template.tags,
|
||||
};
|
||||
});
|
||||
|
||||
//We only want these nodes if we're not filtered
|
||||
if (!pendingConnection) {
|
||||
_options.push({
|
||||
label: t('nodes.currentImage'),
|
||||
value: 'current_image',
|
||||
description: t('nodes.currentImageDescription'),
|
||||
tags: ['progress'],
|
||||
});
|
||||
|
||||
_options.push({
|
||||
label: t('nodes.notes'),
|
||||
value: 'notes',
|
||||
description: t('nodes.notesDescription'),
|
||||
tags: ['notes'],
|
||||
});
|
||||
}
|
||||
|
||||
_options.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return _options;
|
||||
}, [filteredTemplates, pendingConnection, t]);
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string): AnyNode | null => {
|
||||
const node = buildInvocation(nodeType);
|
||||
if (!node) {
|
||||
const errorMessage = t('nodes.unknownNode', {
|
||||
nodeType: nodeType,
|
||||
});
|
||||
toast({
|
||||
status: 'error',
|
||||
title: errorMessage,
|
||||
});
|
||||
return null;
|
||||
}
|
||||
|
||||
// Find a cozy spot for the node
|
||||
const cursorPos = $cursorPos.get();
|
||||
const { nodes, edges } = selectNodesSlice(store.getState());
|
||||
node.position = findUnoccupiedPosition(nodes, cursorPos?.x ?? node.position.x, cursorPos?.y ?? node.position.y);
|
||||
node.selected = true;
|
||||
|
||||
// Deselect all other nodes and edges
|
||||
const nodeChanges: NodeChange[] = [{ type: 'add', item: node }];
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
nodes.forEach(({ id, selected }) => {
|
||||
if (selected) {
|
||||
nodeChanges.push({ type: 'select', id, selected: false });
|
||||
}
|
||||
});
|
||||
edges.forEach(({ id, selected }) => {
|
||||
if (selected) {
|
||||
edgeChanges.push({ type: 'select', id, selected: false });
|
||||
}
|
||||
});
|
||||
|
||||
// Onwards!
|
||||
if (nodeChanges.length > 0) {
|
||||
dispatch(nodesChanged(nodeChanges));
|
||||
}
|
||||
if (edgeChanges.length > 0) {
|
||||
dispatch(edgesChanged(edgeChanges));
|
||||
}
|
||||
return node;
|
||||
},
|
||||
[buildInvocation, store, dispatch, t]
|
||||
);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
const node = addNode(v.value);
|
||||
|
||||
// Auto-connect an edge if we just added a node and have a pending connection
|
||||
if (pendingConnection && isInvocationNode(node)) {
|
||||
const edgePendingUpdate = $edgePendingUpdate.get();
|
||||
const { handleType } = pendingConnection;
|
||||
|
||||
const source = handleType === 'source' ? pendingConnection.nodeId : node.id;
|
||||
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
|
||||
const target = handleType === 'target' ? pendingConnection.nodeId : node.id;
|
||||
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
|
||||
|
||||
const { nodes, edges } = selectNodesSlice(store.getState());
|
||||
const connection = getFirstValidConnection(
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
targetHandle,
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
edgePendingUpdate
|
||||
);
|
||||
if (connection) {
|
||||
const newEdge = connectionToEdge(connection);
|
||||
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
|
||||
}
|
||||
}
|
||||
|
||||
closeAddNodePopover();
|
||||
},
|
||||
[addNode, dispatch, pendingConnection, store, templates]
|
||||
);
|
||||
|
||||
const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
|
||||
if (!$isAddNodePopoverOpen.get()) {
|
||||
e.preventDefault();
|
||||
openAddNodePopover();
|
||||
flushSync(() => {
|
||||
selectRef.current?.inputRef?.focus();
|
||||
});
|
||||
}
|
||||
}, []);
|
||||
|
||||
useHotkeys(['shift+a', 'space'], handleHotkeyOpen, { enabled: isWorkflowsActive }, [isWorkflowsActive]);
|
||||
|
||||
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
isOpen={isOpen}
|
||||
onClose={closeAddNodePopover}
|
||||
placement="bottom"
|
||||
openDelay={0}
|
||||
closeDelay={0}
|
||||
closeOnBlur={true}
|
||||
returnFocusOnClose={true}
|
||||
initialFocusRef={inputRef}
|
||||
isLazy
|
||||
>
|
||||
<PopoverAnchor>
|
||||
<Flex position="absolute" top="15%" insetInlineStart="50%" pointerEvents="none" />
|
||||
</PopoverAnchor>
|
||||
<PopoverContent
|
||||
p={0}
|
||||
top={-1}
|
||||
shadow="dark-lg"
|
||||
borderColor="invokeBlue.400"
|
||||
borderWidth="2px"
|
||||
borderStyle="solid"
|
||||
>
|
||||
<PopoverBody w="32rem" p={0}>
|
||||
<Combobox
|
||||
menuIsOpen={isOpen}
|
||||
selectRef={selectRef}
|
||||
value={null}
|
||||
placeholder={t('nodes.nodeSearch')}
|
||||
options={options}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
filterOption={filterOption}
|
||||
onChange={onChange}
|
||||
onMenuClose={closeAddNodePopover}
|
||||
inputRef={inputRef}
|
||||
closeMenuOnSelect={false}
|
||||
/>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AddNodePopover);
|
||||
@@ -8,10 +8,10 @@ import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||
import {
|
||||
$addNodeCmdk,
|
||||
$cursorPos,
|
||||
$didUpdateEdge,
|
||||
$edgePendingUpdate,
|
||||
$isAddNodePopoverOpen,
|
||||
$lastEdgeUpdateMouseEvent,
|
||||
$pendingConnection,
|
||||
$viewport,
|
||||
@@ -281,7 +281,7 @@ export const Flow = memo(() => {
|
||||
const onEscapeHotkey = useCallback(() => {
|
||||
if (!$edgePendingUpdate.get()) {
|
||||
$pendingConnection.set(null);
|
||||
$isAddNodePopoverOpen.set(false);
|
||||
$addNodeCmdk.set(false);
|
||||
cancelConnection();
|
||||
}
|
||||
}, [cancelConnection]);
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { openAddNodePopover } from 'features/nodes/store/nodesSlice';
|
||||
import { useAddNodeCmdk } from 'features/nodes/store/nodesSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
const AddNodeButton = () => {
|
||||
const addNodeCmdk = useAddNodeCmdk();
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
@@ -12,7 +13,7 @@ const AddNodeButton = () => {
|
||||
tooltip={t('nodes.addNodeToolTip')}
|
||||
aria-label={t('nodes.addNode')}
|
||||
icon={<PiPlusBold />}
|
||||
onClick={openAddNodePopover}
|
||||
onClick={addNodeCmdk.setTrue}
|
||||
pointerEvents="auto"
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -2,9 +2,9 @@ import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import {
|
||||
$addNodeCmdk,
|
||||
$didUpdateEdge,
|
||||
$edgePendingUpdate,
|
||||
$isAddNodePopoverOpen,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
edgesChanged,
|
||||
@@ -107,7 +107,7 @@ export const useConnection = () => {
|
||||
$pendingConnection.set(null);
|
||||
} else {
|
||||
// The mouse is not over a node - we should open the add node popover
|
||||
$isAddNodePopoverOpen.set(true);
|
||||
$addNodeCmdk.set(true);
|
||||
}
|
||||
}, [store, templates, updateNodeInternals]);
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
@@ -443,14 +444,8 @@ export const $didUpdateEdge = atom(false);
|
||||
export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
|
||||
|
||||
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
|
||||
export const $isAddNodePopoverOpen = atom(false);
|
||||
export const closeAddNodePopover = () => {
|
||||
$isAddNodePopoverOpen.set(false);
|
||||
$pendingConnection.set(null);
|
||||
};
|
||||
export const openAddNodePopover = () => {
|
||||
$isAddNodePopoverOpen.set(true);
|
||||
};
|
||||
export const $addNodeCmdk = atom(false);
|
||||
export const useAddNodeCmdk = buildUseBoolean($addNodeCmdk);
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateNodesState = (state: any): any => {
|
||||
|
||||
Reference in New Issue
Block a user