perf(ui): optimized invocation node component structure

This commit is contained in:
psychedelicious
2025-02-15 20:51:49 +10:00
parent 8ecf9fb7e3
commit 01e4fd100f
5 changed files with 119 additions and 60 deletions

View File

@@ -1,9 +1,12 @@
import { Flex, Grid, GridItem } from '@invoke-ai/ui-library';
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
import { OutputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldGate';
import { OutputFieldNodesEditorView } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldNodesEditorView';
import { useInputFieldNamesByStatus } from 'features/nodes/hooks/useInputFieldNamesByStatus';
import {
useInputFieldNamesAnyOrDirect,
useInputFieldNamesConnection,
useInputFieldNamesMissing,
} from 'features/nodes/hooks/useInputFieldNamesByStatus';
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
import { useWithFooter } from 'features/nodes/hooks/useWithFooter';
import { memo } from 'react';
@@ -15,19 +18,14 @@ import InvocationNodeHeader from './InvocationNodeHeader';
type Props = {
nodeId: string;
isOpen: boolean;
label: string;
type: string;
selected: boolean;
};
const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
const fieldNames = useInputFieldNamesByStatus(nodeId);
const InvocationNode = ({ nodeId, isOpen }: Props) => {
const withFooter = useWithFooter(nodeId);
const outputFieldNames = useOutputFieldNames(nodeId);
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<InvocationNodeHeader nodeId={nodeId} isOpen={isOpen} label={label} selected={selected} type={type} />
<>
<InvocationNodeHeader nodeId={nodeId} isOpen={isOpen} />
{isOpen && (
<>
<Flex
@@ -41,38 +39,78 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
>
<Flex flexDir="column" px={2} w="full" h="full">
<Grid gridTemplateColumns="1fr auto" gridAutoRows="1fr">
{fieldNames.connectionFields.map((fieldName, i) => (
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
<InputFieldGate nodeId={nodeId} fieldName={fieldName}>
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
</InputFieldGate>
</GridItem>
))}
{outputFieldNames.map((fieldName, i) => (
<GridItem gridColumnStart={2} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.output-field`}>
<OutputFieldGate nodeId={nodeId} fieldName={fieldName}>
<OutputFieldNodesEditorView nodeId={nodeId} fieldName={fieldName} />
</OutputFieldGate>
</GridItem>
))}
<ConnectionFields nodeId={nodeId} />
<OutputFields nodeId={nodeId} />
</Grid>
{fieldNames.anyOrDirectFields.map((fieldName) => (
<InputFieldGate key={`${nodeId}.${fieldName}.input-field`} nodeId={nodeId} fieldName={fieldName}>
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
</InputFieldGate>
))}
{fieldNames.missingFields.map((fieldName) => (
<InputFieldGate key={`${nodeId}.${fieldName}.input-field`} nodeId={nodeId} fieldName={fieldName}>
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
</InputFieldGate>
))}
<AnyOrDirectFields nodeId={nodeId} />
<MissingFields nodeId={nodeId} />
</Flex>
</Flex>
{withFooter && <InvocationNodeFooter nodeId={nodeId} />}
</>
)}
</NodeWrapper>
</>
);
};
export default memo(InvocationNode);
const ConnectionFields = memo(({ nodeId }: { nodeId: string }) => {
const fieldNames = useInputFieldNamesConnection(nodeId);
return (
<>
{fieldNames.map((fieldName, i) => (
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
<InputFieldGate nodeId={nodeId} fieldName={fieldName}>
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
</InputFieldGate>
</GridItem>
))}
</>
);
});
ConnectionFields.displayName = 'ConnectionFields';
const AnyOrDirectFields = memo(({ nodeId }: { nodeId: string }) => {
const fieldNames = useInputFieldNamesAnyOrDirect(nodeId);
return (
<>
{fieldNames.map((fieldName) => (
<InputFieldGate key={`${nodeId}.${fieldName}.input-field`} nodeId={nodeId} fieldName={fieldName}>
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
</InputFieldGate>
))}
</>
);
});
AnyOrDirectFields.displayName = 'AnyOrDirectFields';
const MissingFields = memo(({ nodeId }: { nodeId: string }) => {
const fieldNames = useInputFieldNamesMissing(nodeId);
return (
<>
{fieldNames.map((fieldName) => (
<InputFieldGate key={`${nodeId}.${fieldName}.input-field`} nodeId={nodeId} fieldName={fieldName}>
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
</InputFieldGate>
))}
</>
);
});
MissingFields.displayName = 'MissingFields';
const OutputFields = memo(({ nodeId }: { nodeId: string }) => {
const fieldNames = useOutputFieldNames(nodeId);
return (
<>
{fieldNames.map((fieldName, i) => (
<GridItem gridColumnStart={2} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.output-field`}>
<OutputFieldGate nodeId={nodeId} fieldName={fieldName}>
<OutputFieldNodesEditorView nodeId={nodeId} fieldName={fieldName} />
</OutputFieldGate>
</GridItem>
))}
</>
);
});
OutputFields.displayName = 'OutputFields';

View File

@@ -11,9 +11,6 @@ import InvocationNodeStatusIndicator from './InvocationNodeStatusIndicator';
type Props = {
nodeId: string;
isOpen: boolean;
label: string;
type: string;
selected: boolean;
};
const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => {

View File

@@ -1,6 +1,5 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton';
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
import { useNodePack } from 'features/nodes/hooks/useNodePack';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { memo } from 'react';
@@ -11,14 +10,13 @@ type Props = {
isOpen: boolean;
label: string;
type: string;
selected: boolean;
};
const InvocationNodeUnknownFallback = ({ nodeId, isOpen, label, type, selected }: Props) => {
const InvocationNodeUnknownFallback = ({ nodeId, isOpen, label, type }: Props) => {
const { t } = useTranslation();
const nodePack = useNodePack(nodeId);
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<>
<Flex
className={DRAG_HANDLE_CLASSNAME}
layerStyle="nodeHeader"
@@ -64,7 +62,7 @@ const InvocationNodeUnknownFallback = ({ nodeId, isOpen, label, type, selected }
</Flex>
</Flex>
)}
</NodeWrapper>
</>
);
};

View File

@@ -2,6 +2,7 @@ import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import type { Node, NodeProps } from '@xyflow/react';
import { useAppSelector } from 'app/store/storeHooks';
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectNodes } from 'features/nodes/store/selectors';
@@ -27,11 +28,17 @@ const InvocationNodeWrapper = (props: NodeProps<Node<InvocationNodeData>>) => {
if (!hasTemplate) {
return (
<InvocationNodeUnknownFallback nodeId={nodeId} isOpen={isOpen} label={label} type={type} selected={selected} />
<NodeWrapper nodeId={nodeId} selected={selected}>
<InvocationNodeUnknownFallback nodeId={nodeId} isOpen={isOpen} label={label} type={type} />
</NodeWrapper>
);
}
return <InvocationNode nodeId={nodeId} isOpen={isOpen} label={label} type={type} selected={selected} />;
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<InvocationNode nodeId={nodeId} isOpen={isOpen} />
</NodeWrapper>
);
};
export default memo(InvocationNodeWrapper);

View File

@@ -3,37 +3,56 @@ import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { isSingleOrCollection } from 'features/nodes/types/field';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { difference, filter, keys } from 'lodash-es';
import { useMemo } from 'react';
const isConnectionInputField = (field: FieldInputTemplate) => {
return (
(field.input === 'connection' && !isSingleOrCollection(field.type)) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
(field.input === 'connection' && !isSingleOrCollection(field.type)) || !(field.type.name in TEMPLATE_BUILDER_MAP)
);
};
const isAnyOrDirectInputField = (field: FieldInputTemplate) => {
return (
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
field.type.name in TEMPLATE_BUILDER_MAP
);
};
export const useInputFieldNamesByStatus = (nodeId: string) => {
export const useInputFieldNamesMissing = (nodeId: string) => {
const template = useNodeTemplate(nodeId);
const node = useNodeData(nodeId);
const fieldNames = useMemo(() => {
const instanceFields = keys(node.inputs);
const allTemplateFields = keys(template.inputs);
const missingFields = difference(instanceFields, allTemplateFields);
const connectionFields = filter(template.inputs, isConnectionInputField).map((f) => f.name);
const anyOrDirectFields = filter(template.inputs, isAnyOrDirectInputField).map((f) => f.name);
return {
missingFields,
connectionFields,
anyOrDirectFields,
};
const instanceFields = new Set(Object.keys(node.inputs));
const allTemplateFields = new Set(Object.keys(template.inputs));
return Array.from(instanceFields.difference(allTemplateFields));
}, [node.inputs, template.inputs]);
return fieldNames;
};
export const useInputFieldNamesAnyOrDirect = (nodeId: string) => {
const template = useNodeTemplate(nodeId);
const fieldNames = useMemo(() => {
const anyOrDirectFields: string[] = [];
for (const [fieldName, fieldTemplate] of Object.entries(template.inputs)) {
if (isAnyOrDirectInputField(fieldTemplate)) {
anyOrDirectFields.push(fieldName);
}
}
return anyOrDirectFields;
}, [template.inputs]);
return fieldNames;
};
export const useInputFieldNamesConnection = (nodeId: string) => {
const template = useNodeTemplate(nodeId);
const fieldNames = useMemo(() => {
const connectionFields: string[] = [];
for (const [fieldName, fieldTemplate] of Object.entries(template.inputs)) {
if (isConnectionInputField(fieldTemplate)) {
connectionFields.push(fieldName);
}
}
return connectionFields;
}, [template.inputs]);
return fieldNames;
};