feat(ui): support filtering on model format

This commit is contained in:
psychedelicious
2025-09-18 11:57:43 +10:00
parent bdeb9fb1cf
commit 035d9432bd
6 changed files with 31 additions and 10 deletions

View File

@@ -55,9 +55,12 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
) {
return false;
}
if (fieldTemplate.ui_model_format && !fieldTemplate.ui_model_format.includes(config.format)) {
return false;
}
return true;
});
}, [data, fieldTemplate.ui_model_base, fieldTemplate.ui_model_type, fieldTemplate.ui_model_variant]);
}, [data, fieldTemplate]);
return (
<ModelFieldCombobox

View File

@@ -2,20 +2,16 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { useInvocationNodeContext } from 'features/nodes/components/flow/nodes/Invocation/context';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { isSingleOrCollection } from 'features/nodes/types/field';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { isSingleOrCollection, isStatefulFieldType } from 'features/nodes/types/field';
import { useMemo } from 'react';
const isConnectionInputField = (field: FieldInputTemplate) => {
return (
(field.input === 'connection' && !isSingleOrCollection(field.type)) || !(field.type.name in TEMPLATE_BUILDER_MAP)
);
return (field.input === 'connection' && !isSingleOrCollection(field.type)) || !isStatefulFieldType(field.type);
};
const isAnyOrDirectInputField = (field: FieldInputTemplate) => {
return (
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
field.type.name in TEMPLATE_BUILDER_MAP
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) && isStatefulFieldType(field.type)
);
};

View File

@@ -13,6 +13,7 @@ import type {
SubModelType,
T2IAdapterField,
zClipVariantType,
zModelFormat,
zModelVariantType,
} from 'features/nodes/types/common';
import type { Invocation, S } from 'services/api/types';
@@ -43,6 +44,7 @@ describe('Common types', () => {
test('ModelIdentifier', () => assert<Equals<SubModelType, S['SubModelType']>>());
test('ClipVariantType', () => assert<Equals<z.infer<typeof zClipVariantType>, S['ClipVariantType']>>());
test('ModelVariantType', () => assert<Equals<z.infer<typeof zModelVariantType>, S['ModelVariantType']>>());
test('ModelFormat', () => assert<Equals<z.infer<typeof zModelFormat>, S['ModelFormat']>>());
// Misc types
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());

View File

@@ -146,6 +146,22 @@ export type SubModelType = z.infer<typeof zSubModelType>;
export const zClipVariantType = z.enum(['large', 'gigantic']);
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
export const zModelFormat = z.enum([
'omi',
'diffusers',
'checkpoint',
'lycoris',
'onnx',
'olive',
'embedding_file',
'embedding_folder',
'invokeai',
't5_encoder',
'bnb_quantized_int8b',
'bnb_quantized_nf4b',
'gguf_quantized',
'api',
]);
export const zModelIdentifierField = z.object({
key: z.string().min(1),

View File

@@ -15,6 +15,7 @@ import {
zClipVariantType,
zColorField,
zImageField,
zModelFormat,
zModelIdentifierField,
zModelType,
zModelVariantType,
@@ -73,6 +74,7 @@ const zFieldInputTemplateBase = zFieldTemplateBase.extend({
ui_model_base: z.array(zBaseModelType).nullish(),
ui_model_type: z.array(zModelType).nullish(),
ui_model_variant: z.array(zModelVariantType.or(zClipVariantType)).nullish(),
ui_model_format: z.array(zModelFormat).nullish(),
});
const zFieldOutputTemplateBase = zFieldTemplateBase.extend({
fieldKind: z.literal('output'),

View File

@@ -449,7 +449,7 @@ const buildImageGeneratorFieldInputTemplate: FieldInputTemplateBuilder<ImageGene
return template;
};
export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
BoardField: buildBoardFieldInputTemplate,
BooleanField: buildBooleanFieldInputTemplate,
ColorField: buildColorFieldInputTemplate,
@@ -464,7 +464,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
StringGeneratorField: buildStringGeneratorFieldInputTemplate,
ImageGeneratorField: buildImageGeneratorFieldInputTemplate,
} as const;
};
export const buildFieldInputTemplate = (
fieldSchema: InvocationFieldSchema,
@@ -482,6 +482,7 @@ export const buildFieldInputTemplate = (
ui_model_base,
ui_model_type,
ui_model_variant,
ui_model_format,
} = fieldSchema;
// This is the base field template that is common to all fields. The builder function will add all other
@@ -501,6 +502,7 @@ export const buildFieldInputTemplate = (
ui_model_base,
ui_model_type,
ui_model_variant,
ui_model_format,
};
if (isStatefulFieldType(fieldType)) {