From 6e9e8d6bd2f1323eb93f1d6d30abb9aaffcf3fb5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:47:31 +1000 Subject: [PATCH] feat(ui): allow changing model type in MM, fix up base and variant selects --- .../web/src/features/modelManagerV2/models.ts | 32 ++++++++++++++++++- .../ModelPanel/Fields/BaseModelSelect.tsx | 13 +++----- .../ModelPanel/Fields/ModelTypeSelect.tsx | 32 +++++++++++++++++++ .../ModelPanel/Fields/ModelVariantSelect.tsx | 8 ++--- .../subpanels/ModelPanel/ModelEdit.tsx | 5 +++ .../web/src/features/nodes/types/common.ts | 2 ++ 6 files changed, 78 insertions(+), 14 deletions(-) create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index 2c0acb1d4d..798ed1ac0e 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -1,4 +1,4 @@ -import type { BaseModelType } from 'features/nodes/types/common'; +import type { BaseModelType, ModelType, ModelVariantType } from 'features/nodes/types/common'; import type { AnyModelConfig } from 'services/api/types'; import { isCLIPEmbedModelConfig, @@ -151,6 +151,30 @@ export const MODEL_BASE_TO_COLOR: Record = { unknown: 'red', }; +/** + * Mapping of model type to human readable name + */ +export const MODEL_TYPE_TO_LONG_NAME: Record = { + main: 'Main', + vae: 'VAE', + lora: 'LoRA', + llava_onevision: 'LLaVA OneVision', + control_lora: 'ControlLoRA', + controlnet: 'ControlNet', + t2i_adapter: 'T2I Adapter', + ip_adapter: 'IP Adapter', + embedding: 'Embedding', + onnx: 'ONNX', + clip_vision: 'CLIP Vision', + spandrel_image_to_image: 'Spandrel (Image to Image)', + t5_encoder: 'T5 Encoder', + clip_embed: 'CLIP Embed', + siglip: 'SigLIP', + flux_redux: 'FLUX Redux', + video: 'Video', + unknown: 'Unknown', +}; + /** * Mapping of model base to human readable name */ @@ -195,6 +219,12 @@ export const MODEL_BASE_TO_SHORT_NAME: Record = { unknown: 'Unknown', }; +export const MODEL_VARIANT_TO_LONG_NAME: Record = { + normal: 'Normal', + inpaint: 'Inpaint', + depth: 'Depth', +}; + /** * List of base models that make API requests */ diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx index fc0c8403e1..8235d26efe 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx @@ -6,15 +6,12 @@ import { useCallback, useMemo } from 'react'; import type { Control } from 'react-hook-form'; import { useController } from 'react-hook-form'; import type { UpdateModelArg } from 'services/api/endpoints/models'; +import { objectEntries } from 'tsafe'; -const options: ComboboxOption[] = [ - { value: 'sd-1', label: MODEL_BASE_TO_LONG_NAME['sd-1'] }, - { value: 'sd-2', label: MODEL_BASE_TO_LONG_NAME['sd-2'] }, - { value: 'sd-3', label: MODEL_BASE_TO_LONG_NAME['sd-3'] }, - { value: 'flux', label: MODEL_BASE_TO_LONG_NAME['flux'] }, - { value: 'sdxl', label: MODEL_BASE_TO_LONG_NAME['sdxl'] }, - { value: 'sdxl-refiner', label: MODEL_BASE_TO_LONG_NAME['sdxl-refiner'] }, -]; +const options: ComboboxOption[] = objectEntries(MODEL_BASE_TO_LONG_NAME).map(([value, label]) => ({ + label, + value, +})); type Props = { control: Control; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx new file mode 100644 index 0000000000..44b41f0151 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx @@ -0,0 +1,32 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { Combobox } from '@invoke-ai/ui-library'; +import { typedMemo } from 'common/util/typedMemo'; +import { MODEL_TYPE_TO_LONG_NAME } from 'features/modelManagerV2/models'; +import { useCallback, useMemo } from 'react'; +import type { Control } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import type { UpdateModelArg } from 'services/api/endpoints/models'; +import { objectEntries } from 'tsafe'; + +const options: ComboboxOption[] = objectEntries(MODEL_TYPE_TO_LONG_NAME).map(([value, label]) => ({ + label, + value, +})); + +type Props = { + control: Control; +}; + +const ModelTypeSelect = ({ control }: Props) => { + const { field } = useController({ control, name: 'type' }); + const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); + const onChange = useCallback( + (v) => { + field.onChange(v?.value); + }, + [field] + ); + return ; +}; + +export default typedMemo(ModelTypeSelect); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx index 6686cc4336..52eb2a4749 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx @@ -1,16 +1,14 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import { Combobox } from '@invoke-ai/ui-library'; import { typedMemo } from 'common/util/typedMemo'; +import { MODEL_VARIANT_TO_LONG_NAME } from 'features/modelManagerV2/models'; import { useCallback, useMemo } from 'react'; import type { Control } from 'react-hook-form'; import { useController } from 'react-hook-form'; import type { UpdateModelArg } from 'services/api/endpoints/models'; +import { objectEntries } from 'tsafe'; -const options: ComboboxOption[] = [ - { value: 'normal', label: 'Normal' }, - { value: 'inpaint', label: 'Inpaint' }, - { value: 'depth', label: 'Depth' }, -]; +const options: ComboboxOption[] = objectEntries(MODEL_VARIANT_TO_LONG_NAME).map(([value, label]) => ({ label, value })); type Props = { control: Control; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx index ff5c680325..12f591f9a3 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx @@ -22,6 +22,7 @@ import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoi import type { AnyModelConfig } from 'services/api/types'; import BaseModelSelect from './Fields/BaseModelSelect'; +import ModelTypeSelect from './Fields/ModelTypeSelect'; import ModelVariantSelect from './Fields/ModelVariantSelect'; import PredictionTypeSelect from './Fields/PredictionTypeSelect'; import { ModelFooter } from './ModelFooter'; @@ -127,6 +128,10 @@ export const ModelEdit = memo(({ modelConfig }: Props) => { )} + + {t('modelManager.modelType')} + + {modelConfig.type !== 'clip_vision' && ( {t('modelManager.baseModel')} diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index d89ac6c823..8bab74f2d2 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -129,6 +129,7 @@ export const zModelType = z.enum([ 'video', 'unknown', ]); +export type ModelType = z.infer; const zSubModelType = z.enum([ 'unet', 'transformer', @@ -148,6 +149,7 @@ export type SubModelType = z.infer; export const zClipVariantType = z.enum(['large', 'gigantic']); export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']); +export type ModelVariantType = z.infer; export const zModelFormat = z.enum([ 'omi', 'diffusers',