feat(ui): allow changing model type in MM, fix up base and variant selects

This commit is contained in:
psychedelicious
2025-09-18 18:47:31 +10:00
parent eb6b3b8168
commit 6e9e8d6bd2
6 changed files with 78 additions and 14 deletions

View File

@@ -1,4 +1,4 @@
import type { BaseModelType } from 'features/nodes/types/common';
import type { BaseModelType, ModelType, ModelVariantType } from 'features/nodes/types/common';
import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
@@ -151,6 +151,30 @@ export const MODEL_BASE_TO_COLOR: Record<BaseModelType, string> = {
unknown: 'red',
};
/**
* Mapping of model type to human readable name
*/
export const MODEL_TYPE_TO_LONG_NAME: Record<ModelType, string> = {
main: 'Main',
vae: 'VAE',
lora: 'LoRA',
llava_onevision: 'LLaVA OneVision',
control_lora: 'ControlLoRA',
controlnet: 'ControlNet',
t2i_adapter: 'T2I Adapter',
ip_adapter: 'IP Adapter',
embedding: 'Embedding',
onnx: 'ONNX',
clip_vision: 'CLIP Vision',
spandrel_image_to_image: 'Spandrel (Image to Image)',
t5_encoder: 'T5 Encoder',
clip_embed: 'CLIP Embed',
siglip: 'SigLIP',
flux_redux: 'FLUX Redux',
video: 'Video',
unknown: 'Unknown',
};
/**
* Mapping of model base to human readable name
*/
@@ -195,6 +219,12 @@ export const MODEL_BASE_TO_SHORT_NAME: Record<BaseModelType, string> = {
unknown: 'Unknown',
};
export const MODEL_VARIANT_TO_LONG_NAME: Record<ModelVariantType, string> = {
normal: 'Normal',
inpaint: 'Inpaint',
depth: 'Depth',
};
/**
* List of base models that make API requests
*/

View File

@@ -6,15 +6,12 @@ import { useCallback, useMemo } from 'react';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { objectEntries } from 'tsafe';
const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_BASE_TO_LONG_NAME['sd-1'] },
{ value: 'sd-2', label: MODEL_BASE_TO_LONG_NAME['sd-2'] },
{ value: 'sd-3', label: MODEL_BASE_TO_LONG_NAME['sd-3'] },
{ value: 'flux', label: MODEL_BASE_TO_LONG_NAME['flux'] },
{ value: 'sdxl', label: MODEL_BASE_TO_LONG_NAME['sdxl'] },
{ value: 'sdxl-refiner', label: MODEL_BASE_TO_LONG_NAME['sdxl-refiner'] },
];
const options: ComboboxOption[] = objectEntries(MODEL_BASE_TO_LONG_NAME).map(([value, label]) => ({
label,
value,
}));
type Props = {
control: Control<UpdateModelArg['body']>;

View File

@@ -0,0 +1,32 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_TO_LONG_NAME } from 'features/modelManagerV2/models';
import { useCallback, useMemo } from 'react';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { objectEntries } from 'tsafe';
const options: ComboboxOption[] = objectEntries(MODEL_TYPE_TO_LONG_NAME).map(([value, label]) => ({
label,
value,
}));
type Props = {
control: Control<UpdateModelArg['body']>;
};
const ModelTypeSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'type' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelTypeSelect);

View File

@@ -1,16 +1,14 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_VARIANT_TO_LONG_NAME } from 'features/modelManagerV2/models';
import { useCallback, useMemo } from 'react';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { objectEntries } from 'tsafe';
const options: ComboboxOption[] = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
const options: ComboboxOption[] = objectEntries(MODEL_VARIANT_TO_LONG_NAME).map(([value, label]) => ({ label, value }));
type Props = {
control: Control<UpdateModelArg['body']>;

View File

@@ -22,6 +22,7 @@ import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoi
import type { AnyModelConfig } from 'services/api/types';
import BaseModelSelect from './Fields/BaseModelSelect';
import ModelTypeSelect from './Fields/ModelTypeSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import { ModelFooter } from './ModelFooter';
@@ -127,6 +128,10 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
</Heading>
)}
<SimpleGrid columns={2} gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect control={form.control} />
</FormControl>
{modelConfig.type !== 'clip_vision' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>

View File

@@ -129,6 +129,7 @@ export const zModelType = z.enum([
'video',
'unknown',
]);
export type ModelType = z.infer<typeof zModelType>;
const zSubModelType = z.enum([
'unet',
'transformer',
@@ -148,6 +149,7 @@ export type SubModelType = z.infer<typeof zSubModelType>;
export const zClipVariantType = z.enum(['large', 'gigantic']);
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
export type ModelVariantType = z.infer<typeof zModelVariantType>;
export const zModelFormat = z.enum([
'omi',
'diffusers',