mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): support filtering on model format
This commit is contained in:
@@ -55,9 +55,12 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
if (fieldTemplate.ui_model_format && !fieldTemplate.ui_model_format.includes(config.format)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}, [data, fieldTemplate.ui_model_base, fieldTemplate.ui_model_type, fieldTemplate.ui_model_variant]);
|
||||
}, [data, fieldTemplate]);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
|
||||
@@ -2,20 +2,16 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useInvocationNodeContext } from 'features/nodes/components/flow/nodes/Invocation/context';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { isSingleOrCollection } from 'features/nodes/types/field';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { isSingleOrCollection, isStatefulFieldType } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
const isConnectionInputField = (field: FieldInputTemplate) => {
|
||||
return (
|
||||
(field.input === 'connection' && !isSingleOrCollection(field.type)) || !(field.type.name in TEMPLATE_BUILDER_MAP)
|
||||
);
|
||||
return (field.input === 'connection' && !isSingleOrCollection(field.type)) || !isStatefulFieldType(field.type);
|
||||
};
|
||||
|
||||
const isAnyOrDirectInputField = (field: FieldInputTemplate) => {
|
||||
return (
|
||||
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
|
||||
field.type.name in TEMPLATE_BUILDER_MAP
|
||||
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) && isStatefulFieldType(field.type)
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import type {
|
||||
SubModelType,
|
||||
T2IAdapterField,
|
||||
zClipVariantType,
|
||||
zModelFormat,
|
||||
zModelVariantType,
|
||||
} from 'features/nodes/types/common';
|
||||
import type { Invocation, S } from 'services/api/types';
|
||||
@@ -43,6 +44,7 @@ describe('Common types', () => {
|
||||
test('ModelIdentifier', () => assert<Equals<SubModelType, S['SubModelType']>>());
|
||||
test('ClipVariantType', () => assert<Equals<z.infer<typeof zClipVariantType>, S['ClipVariantType']>>());
|
||||
test('ModelVariantType', () => assert<Equals<z.infer<typeof zModelVariantType>, S['ModelVariantType']>>());
|
||||
test('ModelFormat', () => assert<Equals<z.infer<typeof zModelFormat>, S['ModelFormat']>>());
|
||||
|
||||
// Misc types
|
||||
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());
|
||||
|
||||
@@ -146,6 +146,22 @@ export type SubModelType = z.infer<typeof zSubModelType>;
|
||||
|
||||
export const zClipVariantType = z.enum(['large', 'gigantic']);
|
||||
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
|
||||
export const zModelFormat = z.enum([
|
||||
'omi',
|
||||
'diffusers',
|
||||
'checkpoint',
|
||||
'lycoris',
|
||||
'onnx',
|
||||
'olive',
|
||||
'embedding_file',
|
||||
'embedding_folder',
|
||||
'invokeai',
|
||||
't5_encoder',
|
||||
'bnb_quantized_int8b',
|
||||
'bnb_quantized_nf4b',
|
||||
'gguf_quantized',
|
||||
'api',
|
||||
]);
|
||||
|
||||
export const zModelIdentifierField = z.object({
|
||||
key: z.string().min(1),
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
zClipVariantType,
|
||||
zColorField,
|
||||
zImageField,
|
||||
zModelFormat,
|
||||
zModelIdentifierField,
|
||||
zModelType,
|
||||
zModelVariantType,
|
||||
@@ -73,6 +74,7 @@ const zFieldInputTemplateBase = zFieldTemplateBase.extend({
|
||||
ui_model_base: z.array(zBaseModelType).nullish(),
|
||||
ui_model_type: z.array(zModelType).nullish(),
|
||||
ui_model_variant: z.array(zModelVariantType.or(zClipVariantType)).nullish(),
|
||||
ui_model_format: z.array(zModelFormat).nullish(),
|
||||
});
|
||||
const zFieldOutputTemplateBase = zFieldTemplateBase.extend({
|
||||
fieldKind: z.literal('output'),
|
||||
|
||||
@@ -449,7 +449,7 @@ const buildImageGeneratorFieldInputTemplate: FieldInputTemplateBuilder<ImageGene
|
||||
return template;
|
||||
};
|
||||
|
||||
export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
|
||||
const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
|
||||
BoardField: buildBoardFieldInputTemplate,
|
||||
BooleanField: buildBooleanFieldInputTemplate,
|
||||
ColorField: buildColorFieldInputTemplate,
|
||||
@@ -464,7 +464,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
|
||||
StringGeneratorField: buildStringGeneratorFieldInputTemplate,
|
||||
ImageGeneratorField: buildImageGeneratorFieldInputTemplate,
|
||||
} as const;
|
||||
};
|
||||
|
||||
export const buildFieldInputTemplate = (
|
||||
fieldSchema: InvocationFieldSchema,
|
||||
@@ -482,6 +482,7 @@ export const buildFieldInputTemplate = (
|
||||
ui_model_base,
|
||||
ui_model_type,
|
||||
ui_model_variant,
|
||||
ui_model_format,
|
||||
} = fieldSchema;
|
||||
|
||||
// This is the base field template that is common to all fields. The builder function will add all other
|
||||
@@ -501,6 +502,7 @@ export const buildFieldInputTemplate = (
|
||||
ui_model_base,
|
||||
ui_model_type,
|
||||
ui_model_variant,
|
||||
ui_model_format,
|
||||
};
|
||||
|
||||
if (isStatefulFieldType(fieldType)) {
|
||||
|
||||
Reference in New Issue
Block a user