mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
fix(ui): model field types not recognized as such during workflow validation and field styling
This commit is contained in:
@@ -9,8 +9,8 @@ import {
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import { type FieldInputTemplate,isModelFieldType } from 'features/nodes/types/field';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -64,7 +64,7 @@ export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
|
||||
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
|
||||
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
|
||||
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
|
||||
const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]);
|
||||
const isModelField = useMemo(() => isModelFieldType(fieldTemplate.type), [fieldTemplate.type]);
|
||||
const isConnectionInProgress = useIsConnectionInProgress();
|
||||
|
||||
if (isConnectionInProgress) {
|
||||
|
||||
@@ -9,8 +9,8 @@ import {
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
|
||||
import type { FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import { type FieldOutputTemplate,isModelFieldType } from 'features/nodes/types/field';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -64,7 +64,7 @@ export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
|
||||
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
|
||||
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
|
||||
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
|
||||
const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]);
|
||||
const isModelField = useMemo(() => isModelFieldType(fieldTemplate.type), [fieldTemplate.type]);
|
||||
const isConnectionInProgress = useIsConnectionInProgress();
|
||||
|
||||
if (isConnectionInProgress) {
|
||||
|
||||
@@ -23,27 +23,6 @@ export const SHARED_NODE_PROPERTIES: Partial<AnyNode> = {
|
||||
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
|
||||
};
|
||||
|
||||
/**
|
||||
* Model types' handles are rendered as squares in the UI.
|
||||
*/
|
||||
export const MODEL_TYPES = [
|
||||
'IPAdapterModelField',
|
||||
'ControlNetModelField',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'FluxMainModelField',
|
||||
'SD3MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
'UNetField',
|
||||
'VAEField',
|
||||
'CLIPField',
|
||||
'T2IAdapterModelField',
|
||||
'T5EncoderField',
|
||||
'SpandrelImageToImageModelField',
|
||||
];
|
||||
|
||||
/**
|
||||
* Colors for each field type - applies to their handles and edges.
|
||||
*/
|
||||
|
||||
@@ -278,6 +278,38 @@ export const isStatefulFieldType = (fieldType: FieldType): fieldType is Stateful
|
||||
const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
|
||||
const modelFieldTypeNames = [
|
||||
// Stateful model fields
|
||||
zModelIdentifierFieldType.shape.name.value,
|
||||
zMainModelFieldType.shape.name.value,
|
||||
zSDXLMainModelFieldType.shape.name.value,
|
||||
zSD3MainModelFieldType.shape.name.value,
|
||||
zFluxMainModelFieldType.shape.name.value,
|
||||
zSDXLRefinerModelFieldType.shape.name.value,
|
||||
zVAEModelFieldType.shape.name.value,
|
||||
zLoRAModelFieldType.shape.name.value,
|
||||
zControlNetModelFieldType.shape.name.value,
|
||||
zIPAdapterModelFieldType.shape.name.value,
|
||||
zT2IAdapterModelFieldType.shape.name.value,
|
||||
zSpandrelImageToImageModelFieldType.shape.name.value,
|
||||
zT5EncoderModelFieldType.shape.name.value,
|
||||
zCLIPEmbedModelFieldType.shape.name.value,
|
||||
zCLIPLEmbedModelFieldType.shape.name.value,
|
||||
zCLIPGEmbedModelFieldType.shape.name.value,
|
||||
zControlLoRAModelFieldType.shape.name.value,
|
||||
zFluxVAEModelFieldType.shape.name.value,
|
||||
// Stateless model fields
|
||||
'UNetField',
|
||||
'VAEField',
|
||||
'CLIPField',
|
||||
'T5EncoderField',
|
||||
'TransformerField',
|
||||
'ControlLoRAField',
|
||||
];
|
||||
export const isModelFieldType = (fieldType: FieldType) => {
|
||||
return (modelFieldTypeNames as string[]).includes(fieldType.name);
|
||||
};
|
||||
|
||||
export const isSingle = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.SINGLE;
|
||||
export const isCollection = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.COLLECTION;
|
||||
export const isSingleOrCollection = (fieldType: FieldType): boolean =>
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
isBoardFieldInputInstance,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldInputInstance,
|
||||
isModelFieldType,
|
||||
isModelIdentifierFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
@@ -25,21 +26,6 @@ type ValidateWorkflowResult = {
|
||||
warnings: WorkflowWarning[];
|
||||
};
|
||||
|
||||
const MODEL_FIELD_TYPES = [
|
||||
'ModelIdentifier',
|
||||
'MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'FluxMainModelField',
|
||||
'SD3MainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VAEModelField',
|
||||
'LoRAModelField',
|
||||
'ControlNetModelField',
|
||||
'IPAdapterModelField',
|
||||
'T2IAdapterModelField',
|
||||
'SpandrelImageToImageModelField',
|
||||
];
|
||||
|
||||
/**
|
||||
* Parses and validates a workflow:
|
||||
* - Parses the workflow schema, and migrates it to the latest version if necessary.
|
||||
@@ -123,7 +109,7 @@ export const validateWorkflow = async (
|
||||
|
||||
// We need to confirm that all images, boards and models are accessible before loading,
|
||||
// else the workflow could end up with stale data an an error state.
|
||||
if (fieldTemplate.type.name === 'ImageField' && isImageFieldInputInstance(input) && input.value) {
|
||||
if (fieldTemplate.type.name === 'ImageField' && input.value && isImageFieldInputInstance(input)) {
|
||||
const hasAccess = await checkImageAccess(input.value.image_name);
|
||||
if (!hasAccess) {
|
||||
const message = t('nodes.imageAccessError', { image_name: input.value.image_name });
|
||||
@@ -131,7 +117,7 @@ export const validateWorkflow = async (
|
||||
input.value = undefined;
|
||||
}
|
||||
}
|
||||
if (fieldTemplate.type.name === 'ImageField' && isImageFieldCollectionInputInstance(input) && input.value) {
|
||||
if (fieldTemplate.type.name === 'ImageField' && input.value && isImageFieldCollectionInputInstance(input)) {
|
||||
for (const { image_name } of [...input.value]) {
|
||||
const hasAccess = await checkImageAccess(image_name);
|
||||
if (!hasAccess) {
|
||||
@@ -141,7 +127,7 @@ export const validateWorkflow = async (
|
||||
}
|
||||
}
|
||||
}
|
||||
if (fieldTemplate.type.name === 'BoardField' && isBoardFieldInputInstance(input) && input.value) {
|
||||
if (fieldTemplate.type.name === 'BoardField' && input.value && isBoardFieldInputInstance(input)) {
|
||||
const hasAccess = await checkBoardAccess(input.value.board_id);
|
||||
if (!hasAccess) {
|
||||
const message = t('nodes.boardAccessError', { board_id: input.value.board_id });
|
||||
@@ -149,11 +135,7 @@ export const validateWorkflow = async (
|
||||
input.value = undefined;
|
||||
}
|
||||
}
|
||||
if (
|
||||
MODEL_FIELD_TYPES.includes(fieldTemplate.type.name) &&
|
||||
isModelIdentifierFieldInputInstance(input) &&
|
||||
input.value
|
||||
) {
|
||||
if (isModelFieldType(fieldTemplate.type) && input.value && isModelIdentifierFieldInputInstance(input)) {
|
||||
const hasAccess = await checkModelAccess(input.value.key);
|
||||
if (!hasAccess) {
|
||||
const message = t('nodes.modelAccessError', { key: input.value.key });
|
||||
|
||||
Reference in New Issue
Block a user