mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
perf(ui): optimized invocation node component structure
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
import { Flex, Grid, GridItem } from '@invoke-ai/ui-library';
|
||||
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
|
||||
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
|
||||
import { OutputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldGate';
|
||||
import { OutputFieldNodesEditorView } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldNodesEditorView';
|
||||
import { useInputFieldNamesByStatus } from 'features/nodes/hooks/useInputFieldNamesByStatus';
|
||||
import {
|
||||
useInputFieldNamesAnyOrDirect,
|
||||
useInputFieldNamesConnection,
|
||||
useInputFieldNamesMissing,
|
||||
} from 'features/nodes/hooks/useInputFieldNamesByStatus';
|
||||
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
|
||||
import { useWithFooter } from 'features/nodes/hooks/useWithFooter';
|
||||
import { memo } from 'react';
|
||||
@@ -15,19 +18,14 @@ import InvocationNodeHeader from './InvocationNodeHeader';
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
isOpen: boolean;
|
||||
label: string;
|
||||
type: string;
|
||||
selected: boolean;
|
||||
};
|
||||
|
||||
const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
const fieldNames = useInputFieldNamesByStatus(nodeId);
|
||||
const InvocationNode = ({ nodeId, isOpen }: Props) => {
|
||||
const withFooter = useWithFooter(nodeId);
|
||||
const outputFieldNames = useOutputFieldNames(nodeId);
|
||||
|
||||
return (
|
||||
<NodeWrapper nodeId={nodeId} selected={selected}>
|
||||
<InvocationNodeHeader nodeId={nodeId} isOpen={isOpen} label={label} selected={selected} type={type} />
|
||||
<>
|
||||
<InvocationNodeHeader nodeId={nodeId} isOpen={isOpen} />
|
||||
{isOpen && (
|
||||
<>
|
||||
<Flex
|
||||
@@ -41,38 +39,78 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
>
|
||||
<Flex flexDir="column" px={2} w="full" h="full">
|
||||
<Grid gridTemplateColumns="1fr auto" gridAutoRows="1fr">
|
||||
{fieldNames.connectionFields.map((fieldName, i) => (
|
||||
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
|
||||
<InputFieldGate nodeId={nodeId} fieldName={fieldName}>
|
||||
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
|
||||
</InputFieldGate>
|
||||
</GridItem>
|
||||
))}
|
||||
{outputFieldNames.map((fieldName, i) => (
|
||||
<GridItem gridColumnStart={2} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.output-field`}>
|
||||
<OutputFieldGate nodeId={nodeId} fieldName={fieldName}>
|
||||
<OutputFieldNodesEditorView nodeId={nodeId} fieldName={fieldName} />
|
||||
</OutputFieldGate>
|
||||
</GridItem>
|
||||
))}
|
||||
<ConnectionFields nodeId={nodeId} />
|
||||
<OutputFields nodeId={nodeId} />
|
||||
</Grid>
|
||||
{fieldNames.anyOrDirectFields.map((fieldName) => (
|
||||
<InputFieldGate key={`${nodeId}.${fieldName}.input-field`} nodeId={nodeId} fieldName={fieldName}>
|
||||
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
|
||||
</InputFieldGate>
|
||||
))}
|
||||
{fieldNames.missingFields.map((fieldName) => (
|
||||
<InputFieldGate key={`${nodeId}.${fieldName}.input-field`} nodeId={nodeId} fieldName={fieldName}>
|
||||
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
|
||||
</InputFieldGate>
|
||||
))}
|
||||
<AnyOrDirectFields nodeId={nodeId} />
|
||||
<MissingFields nodeId={nodeId} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
{withFooter && <InvocationNodeFooter nodeId={nodeId} />}
|
||||
</>
|
||||
)}
|
||||
</NodeWrapper>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(InvocationNode);
|
||||
|
||||
const ConnectionFields = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const fieldNames = useInputFieldNamesConnection(nodeId);
|
||||
return (
|
||||
<>
|
||||
{fieldNames.map((fieldName, i) => (
|
||||
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
|
||||
<InputFieldGate nodeId={nodeId} fieldName={fieldName}>
|
||||
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
|
||||
</InputFieldGate>
|
||||
</GridItem>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
});
|
||||
ConnectionFields.displayName = 'ConnectionFields';
|
||||
|
||||
const AnyOrDirectFields = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const fieldNames = useInputFieldNamesAnyOrDirect(nodeId);
|
||||
return (
|
||||
<>
|
||||
{fieldNames.map((fieldName) => (
|
||||
<InputFieldGate key={`${nodeId}.${fieldName}.input-field`} nodeId={nodeId} fieldName={fieldName}>
|
||||
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
|
||||
</InputFieldGate>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
});
|
||||
AnyOrDirectFields.displayName = 'AnyOrDirectFields';
|
||||
|
||||
const MissingFields = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const fieldNames = useInputFieldNamesMissing(nodeId);
|
||||
return (
|
||||
<>
|
||||
{fieldNames.map((fieldName) => (
|
||||
<InputFieldGate key={`${nodeId}.${fieldName}.input-field`} nodeId={nodeId} fieldName={fieldName}>
|
||||
<InputFieldEditModeNodes nodeId={nodeId} fieldName={fieldName} />
|
||||
</InputFieldGate>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
});
|
||||
MissingFields.displayName = 'MissingFields';
|
||||
|
||||
const OutputFields = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const fieldNames = useOutputFieldNames(nodeId);
|
||||
return (
|
||||
<>
|
||||
{fieldNames.map((fieldName, i) => (
|
||||
<GridItem gridColumnStart={2} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.output-field`}>
|
||||
<OutputFieldGate nodeId={nodeId} fieldName={fieldName}>
|
||||
<OutputFieldNodesEditorView nodeId={nodeId} fieldName={fieldName} />
|
||||
</OutputFieldGate>
|
||||
</GridItem>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
});
|
||||
OutputFields.displayName = 'OutputFields';
|
||||
|
||||
@@ -11,9 +11,6 @@ import InvocationNodeStatusIndicator from './InvocationNodeStatusIndicator';
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
isOpen: boolean;
|
||||
label: string;
|
||||
type: string;
|
||||
selected: boolean;
|
||||
};
|
||||
|
||||
const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton';
|
||||
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
|
||||
import { useNodePack } from 'features/nodes/hooks/useNodePack';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { memo } from 'react';
|
||||
@@ -11,14 +10,13 @@ type Props = {
|
||||
isOpen: boolean;
|
||||
label: string;
|
||||
type: string;
|
||||
selected: boolean;
|
||||
};
|
||||
|
||||
const InvocationNodeUnknownFallback = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
const InvocationNodeUnknownFallback = ({ nodeId, isOpen, label, type }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const nodePack = useNodePack(nodeId);
|
||||
return (
|
||||
<NodeWrapper nodeId={nodeId} selected={selected}>
|
||||
<>
|
||||
<Flex
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
layerStyle="nodeHeader"
|
||||
@@ -64,7 +62,7 @@ const InvocationNodeUnknownFallback = ({ nodeId, isOpen, label, type, selected }
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
</NodeWrapper>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import type { Node, NodeProps } from '@xyflow/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
|
||||
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodes } from 'features/nodes/store/selectors';
|
||||
@@ -27,11 +28,17 @@ const InvocationNodeWrapper = (props: NodeProps<Node<InvocationNodeData>>) => {
|
||||
|
||||
if (!hasTemplate) {
|
||||
return (
|
||||
<InvocationNodeUnknownFallback nodeId={nodeId} isOpen={isOpen} label={label} type={type} selected={selected} />
|
||||
<NodeWrapper nodeId={nodeId} selected={selected}>
|
||||
<InvocationNodeUnknownFallback nodeId={nodeId} isOpen={isOpen} label={label} type={type} />
|
||||
</NodeWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
return <InvocationNode nodeId={nodeId} isOpen={isOpen} label={label} type={type} selected={selected} />;
|
||||
return (
|
||||
<NodeWrapper nodeId={nodeId} selected={selected}>
|
||||
<InvocationNode nodeId={nodeId} isOpen={isOpen} />
|
||||
</NodeWrapper>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(InvocationNodeWrapper);
|
||||
|
||||
@@ -3,37 +3,56 @@ import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { isSingleOrCollection } from 'features/nodes/types/field';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { difference, filter, keys } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
const isConnectionInputField = (field: FieldInputTemplate) => {
|
||||
return (
|
||||
(field.input === 'connection' && !isSingleOrCollection(field.type)) ||
|
||||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
(field.input === 'connection' && !isSingleOrCollection(field.type)) || !(field.type.name in TEMPLATE_BUILDER_MAP)
|
||||
);
|
||||
};
|
||||
|
||||
const isAnyOrDirectInputField = (field: FieldInputTemplate) => {
|
||||
return (
|
||||
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
|
||||
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
field.type.name in TEMPLATE_BUILDER_MAP
|
||||
);
|
||||
};
|
||||
|
||||
export const useInputFieldNamesByStatus = (nodeId: string) => {
|
||||
export const useInputFieldNamesMissing = (nodeId: string) => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const node = useNodeData(nodeId);
|
||||
const fieldNames = useMemo(() => {
|
||||
const instanceFields = keys(node.inputs);
|
||||
const allTemplateFields = keys(template.inputs);
|
||||
const missingFields = difference(instanceFields, allTemplateFields);
|
||||
const connectionFields = filter(template.inputs, isConnectionInputField).map((f) => f.name);
|
||||
const anyOrDirectFields = filter(template.inputs, isAnyOrDirectInputField).map((f) => f.name);
|
||||
return {
|
||||
missingFields,
|
||||
connectionFields,
|
||||
anyOrDirectFields,
|
||||
};
|
||||
const instanceFields = new Set(Object.keys(node.inputs));
|
||||
const allTemplateFields = new Set(Object.keys(template.inputs));
|
||||
return Array.from(instanceFields.difference(allTemplateFields));
|
||||
}, [node.inputs, template.inputs]);
|
||||
return fieldNames;
|
||||
};
|
||||
|
||||
export const useInputFieldNamesAnyOrDirect = (nodeId: string) => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldNames = useMemo(() => {
|
||||
const anyOrDirectFields: string[] = [];
|
||||
for (const [fieldName, fieldTemplate] of Object.entries(template.inputs)) {
|
||||
if (isAnyOrDirectInputField(fieldTemplate)) {
|
||||
anyOrDirectFields.push(fieldName);
|
||||
}
|
||||
}
|
||||
return anyOrDirectFields;
|
||||
}, [template.inputs]);
|
||||
return fieldNames;
|
||||
};
|
||||
|
||||
export const useInputFieldNamesConnection = (nodeId: string) => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldNames = useMemo(() => {
|
||||
const connectionFields: string[] = [];
|
||||
for (const [fieldName, fieldTemplate] of Object.entries(template.inputs)) {
|
||||
if (isConnectionInputField(fieldTemplate)) {
|
||||
connectionFields.push(fieldName);
|
||||
}
|
||||
}
|
||||
return connectionFields;
|
||||
}, [template.inputs]);
|
||||
return fieldNames;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user