diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index e3e817a424..76af8fa8db 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -18,84 +18,117 @@ import { isTIModelConfig, isUnknownModelConfig, isVAEModelConfig, + isVideoModelConfig, } from 'services/api/types'; +import { objectEntries } from 'tsafe'; -type ModelCategoryData = { +import type { FilterableModelType } from './store/modelManagerV2Slice'; + +export type ModelCategoryData = { + category: FilterableModelType; i18nKey: string; filter: (config: AnyModelConfig) => boolean; }; -export const MODEL_CATEGORIES: Record = { +export const MODEL_CATEGORIES: Record = { + unknown: { + category: 'unknown', + i18nKey: 'common.unknown', + filter: isUnknownModelConfig, + }, main: { - i18nKey: 'model_manager.category.main_models', + category: 'main', + i18nKey: 'modelManager.main', filter: isNonRefinerMainModelConfig, }, refiner: { - i18nKey: 'model_manager.category.refiner_models', + category: 'refiner', + i18nKey: 'sdxl.refiner', filter: isRefinerMainModelModelConfig, }, lora: { - i18nKey: 'model_manager.category.lora_models', + category: 'lora', + i18nKey: 'modelManager.loraModels', filter: isLoRAModelConfig, }, embedding: { - i18nKey: 'model_manager.category.embedding_models', + category: 'embedding', + i18nKey: 'modelManager.textualInversions', filter: isTIModelConfig, }, controlnet: { - i18nKey: 'model_manager.category.controlnet_models', + category: 'controlnet', + i18nKey: 'ControlNet', filter: isControlNetModelConfig, }, t2i_adapter: { - i18nKey: 'model_manager.category.t2i_adapter_models', + category: 't2i_adapter', + i18nKey: 'common.t2iAdapter', filter: isT2IAdapterModelConfig, }, t5_encoder: { - i18nKey: 'model_manager.category.t5_encoder_models', + category: 't5_encoder', + i18nKey: 'modelManager.t5Encoder', filter: isT5EncoderModelConfig, }, control_lora: { - i18nKey: 'model_manager.category.control_lora_models', + category: 'control_lora', + i18nKey: 'modelManager.controlLora', filter: isControlLoRAModelConfig, }, clip_embed: { - i18nKey: 'model_manager.category.clip_embed_models', + category: 'clip_embed', + i18nKey: 'modelManager.clipEmbed', filter: isCLIPEmbedModelConfig, }, - spandrel: { - i18nKey: 'model_manager.category.spandrel_image_to_image_models', + spandrel_image_to_image: { + category: 'spandrel_image_to_image', + i18nKey: 'modelManager.spandrelImageToImage', filter: isSpandrelImageToImageModelConfig, }, ip_adapter: { - i18nKey: 'model_manager.category.ip_adapter_models', + category: 'ip_adapter', + i18nKey: 'common.ipAdapter', filter: isIPAdapterModelConfig, }, vae: { - i18nKey: 'model_manager.category.vae_models', + category: 'vae', + i18nKey: 'VAE', filter: isVAEModelConfig, }, clip_vision: { - i18nKey: 'model_manager.category.clip_vision_models', + category: 'clip_vision', + i18nKey: 'CLIP Vision', filter: isCLIPVisionModelConfig, }, siglip: { - i18nKey: 'model_manager.category.siglip_models', + category: 'siglip', + i18nKey: 'modelManager.sigLip', filter: isSigLipModelConfig, }, flux_redux: { - i18nKey: 'model_manager.category.flux_redux_models', + category: 'flux_redux', + i18nKey: 'modelManager.fluxRedux', filter: isFluxReduxModelConfig, }, - llava_one_vision: { - i18nKey: 'model_manager.category.llava_one_vision_models', + llava_onevision: { + category: 'llava_onevision', + i18nKey: 'modelManager.llavaOnevision', filter: isLLaVAModelConfig, }, - unknown: { - i18nKey: 'model_manager.category.unknown_models', - filter: isUnknownModelConfig, + video: { + category: 'video', + i18nKey: 'Video', + filter: isVideoModelConfig, }, }; +export const MODEL_CATEGORIES_AS_LIST = objectEntries(MODEL_CATEGORIES).map(([category, { i18nKey, filter }]) => ({ + category, + i18nKey, + filter, +})); + /** * Mapping of model base to its color */ diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index 42f957391d..bde3f1d594 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -2,7 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import { useAppSelector } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; -import { MODEL_CATEGORIES } from 'features/modelManagerV2/models'; +import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; import { type FilterableModelType, selectFilteredModelType, @@ -31,7 +31,7 @@ const ModelList = () => { const byCategory: { i18nKey: string; configs: AnyModelConfig[] }[] = []; const total = baseFilteredModelConfigs.length; let renderedTotal = 0; - for (const { i18nKey, filter } of Object.values(MODEL_CATEGORIES)) { + for (const { i18nKey, filter } of MODEL_CATEGORIES_AS_LIST) { const configs = baseFilteredModelConfigs.filter(filter); renderedTotal += configs.length; byCategory.push({ i18nKey, configs }); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx index d033c2b562..0ee479e86b 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx @@ -1,47 +1,17 @@ import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; +import type { ModelCategoryData } from 'features/modelManagerV2/models'; +import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiFunnelBold } from 'react-icons/pi'; -import { objectKeys } from 'tsafe'; export const ModelTypeFilter = memo(() => { const { t } = useTranslation(); const dispatch = useAppDispatch(); - const MODEL_TYPE_LABELS: Record = useMemo( - () => ({ - main: t('modelManager.main'), - refiner: t('sdxl.refiner'), - lora: 'LoRA', - embedding: t('modelManager.textualInversions'), - controlnet: 'ControlNet', - vae: 'VAE', - t2i_adapter: t('common.t2iAdapter'), - t5_encoder: t('modelManager.t5Encoder'), - clip_embed: t('modelManager.clipEmbed'), - ip_adapter: t('common.ipAdapter'), - clip_vision: 'CLIP Vision', - spandrel_image_to_image: t('modelManager.spandrelImageToImage'), - control_lora: t('modelManager.controlLora'), - siglip: t('modelManager.sigLip'), - flux_redux: t('modelManager.fluxRedux'), - llava_onevision: t('modelManager.llavaOnevision'), - video: t('modelManager.video'), - unknown: t('modelManager.unknown'), - }), - [t] - ); const filteredModelType = useAppSelector(selectFilteredModelType); - const selectModelType = useCallback( - (option: FilterableModelType) => { - dispatch(setFilteredModelType(option)); - }, - [dispatch] - ); - const clearModelType = useCallback(() => { dispatch(setFilteredModelType(null)); }, [dispatch]); @@ -49,18 +19,12 @@ export const ModelTypeFilter = memo(() => { return ( }> - {filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')} + {filteredModelType ? t(MODEL_CATEGORIES[filteredModelType].i18nKey) : t('modelManager.allModels')} {t('modelManager.allModels')} - {objectKeys(MODEL_TYPE_LABELS).map((option) => ( - - {MODEL_TYPE_LABELS[option]} - + {MODEL_CATEGORIES_AS_LIST.map((data) => ( + ))} @@ -68,3 +32,18 @@ export const ModelTypeFilter = memo(() => { }); ModelTypeFilter.displayName = 'ModelTypeFilter'; + +const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector(selectFilteredModelType); + const onClick = useCallback(() => { + dispatch(setFilteredModelType(data.category)); + }, [data.category, dispatch]); + return ( + + {t(data.i18nKey)} + + ); +}); +ModelMenuItem.displayName = 'ModelMenuItem';