fix(ui): model field types not recognized as such during workflow validation and field styling

This commit is contained in:
psychedelicious
2025-02-19 13:35:41 +10:00
parent 4bba7de070
commit d0a231d59e
5 changed files with 43 additions and 50 deletions

View File

@@ -9,8 +9,8 @@ import {
} from 'features/nodes/hooks/useFieldConnectionState';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { type FieldInputTemplate,isModelFieldType } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -64,7 +64,7 @@ export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]);
const isModelField = useMemo(() => isModelFieldType(fieldTemplate.type), [fieldTemplate.type]);
const isConnectionInProgress = useIsConnectionInProgress();
if (isConnectionInProgress) {

View File

@@ -9,8 +9,8 @@ import {
} from 'features/nodes/hooks/useFieldConnectionState';
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
import type { FieldOutputTemplate } from 'features/nodes/types/field';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { type FieldOutputTemplate,isModelFieldType } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -64,7 +64,7 @@ export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]);
const isModelField = useMemo(() => isModelFieldType(fieldTemplate.type), [fieldTemplate.type]);
const isConnectionInProgress = useIsConnectionInProgress();
if (isConnectionInProgress) {

View File

@@ -23,27 +23,6 @@ export const SHARED_NODE_PROPERTIES: Partial<AnyNode> = {
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
};
/**
* Model types' handles are rendered as squares in the UI.
*/
export const MODEL_TYPES = [
'IPAdapterModelField',
'ControlNetModelField',
'LoRAModelField',
'MainModelField',
'FluxMainModelField',
'SD3MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'UNetField',
'VAEField',
'CLIPField',
'T2IAdapterModelField',
'T5EncoderField',
'SpandrelImageToImageModelField',
];
/**
* Colors for each field type - applies to their handles and edges.
*/

View File

@@ -278,6 +278,38 @@ export const isStatefulFieldType = (fieldType: FieldType): fieldType is Stateful
const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
export type FieldType = z.infer<typeof zFieldType>;
const modelFieldTypeNames = [
// Stateful model fields
zModelIdentifierFieldType.shape.name.value,
zMainModelFieldType.shape.name.value,
zSDXLMainModelFieldType.shape.name.value,
zSD3MainModelFieldType.shape.name.value,
zFluxMainModelFieldType.shape.name.value,
zSDXLRefinerModelFieldType.shape.name.value,
zVAEModelFieldType.shape.name.value,
zLoRAModelFieldType.shape.name.value,
zControlNetModelFieldType.shape.name.value,
zIPAdapterModelFieldType.shape.name.value,
zT2IAdapterModelFieldType.shape.name.value,
zSpandrelImageToImageModelFieldType.shape.name.value,
zT5EncoderModelFieldType.shape.name.value,
zCLIPEmbedModelFieldType.shape.name.value,
zCLIPLEmbedModelFieldType.shape.name.value,
zCLIPGEmbedModelFieldType.shape.name.value,
zControlLoRAModelFieldType.shape.name.value,
zFluxVAEModelFieldType.shape.name.value,
// Stateless model fields
'UNetField',
'VAEField',
'CLIPField',
'T5EncoderField',
'TransformerField',
'ControlLoRAField',
];
export const isModelFieldType = (fieldType: FieldType) => {
return (modelFieldTypeNames as string[]).includes(fieldType.name);
};
export const isSingle = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.SINGLE;
export const isCollection = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.COLLECTION;
export const isSingleOrCollection = (fieldType: FieldType): boolean =>

View File

@@ -4,6 +4,7 @@ import {
isBoardFieldInputInstance,
isImageFieldCollectionInputInstance,
isImageFieldInputInstance,
isModelFieldType,
isModelIdentifierFieldInputInstance,
} from 'features/nodes/types/field';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
@@ -25,21 +26,6 @@ type ValidateWorkflowResult = {
warnings: WorkflowWarning[];
};
const MODEL_FIELD_TYPES = [
'ModelIdentifier',
'MainModelField',
'SDXLMainModelField',
'FluxMainModelField',
'SD3MainModelField',
'SDXLRefinerModelField',
'VAEModelField',
'LoRAModelField',
'ControlNetModelField',
'IPAdapterModelField',
'T2IAdapterModelField',
'SpandrelImageToImageModelField',
];
/**
* Parses and validates a workflow:
* - Parses the workflow schema, and migrates it to the latest version if necessary.
@@ -123,7 +109,7 @@ export const validateWorkflow = async (
// We need to confirm that all images, boards and models are accessible before loading,
// else the workflow could end up with stale data an an error state.
if (fieldTemplate.type.name === 'ImageField' && isImageFieldInputInstance(input) && input.value) {
if (fieldTemplate.type.name === 'ImageField' && input.value && isImageFieldInputInstance(input)) {
const hasAccess = await checkImageAccess(input.value.image_name);
if (!hasAccess) {
const message = t('nodes.imageAccessError', { image_name: input.value.image_name });
@@ -131,7 +117,7 @@ export const validateWorkflow = async (
input.value = undefined;
}
}
if (fieldTemplate.type.name === 'ImageField' && isImageFieldCollectionInputInstance(input) && input.value) {
if (fieldTemplate.type.name === 'ImageField' && input.value && isImageFieldCollectionInputInstance(input)) {
for (const { image_name } of [...input.value]) {
const hasAccess = await checkImageAccess(image_name);
if (!hasAccess) {
@@ -141,7 +127,7 @@ export const validateWorkflow = async (
}
}
}
if (fieldTemplate.type.name === 'BoardField' && isBoardFieldInputInstance(input) && input.value) {
if (fieldTemplate.type.name === 'BoardField' && input.value && isBoardFieldInputInstance(input)) {
const hasAccess = await checkBoardAccess(input.value.board_id);
if (!hasAccess) {
const message = t('nodes.boardAccessError', { board_id: input.value.board_id });
@@ -149,11 +135,7 @@ export const validateWorkflow = async (
input.value = undefined;
}
}
if (
MODEL_FIELD_TYPES.includes(fieldTemplate.type.name) &&
isModelIdentifierFieldInputInstance(input) &&
input.value
) {
if (isModelFieldType(fieldTemplate.type) && input.value && isModelIdentifierFieldInputInstance(input)) {
const hasAccess = await checkModelAccess(input.value.key);
if (!hasAccess) {
const message = t('nodes.modelAccessError', { key: input.value.key });