refactor(ui)refactor(ui): more cleanup of model categories

This commit is contained in:
psychedelicious
2025-09-18 18:15:39 +10:00
parent b68871a13f
commit bd893cf3f6
3 changed files with 79 additions and 67 deletions

View File

@@ -18,84 +18,117 @@ import {
isTIModelConfig,
isUnknownModelConfig,
isVAEModelConfig,
isVideoModelConfig,
} from 'services/api/types';
import { objectEntries } from 'tsafe';
type ModelCategoryData = {
import type { FilterableModelType } from './store/modelManagerV2Slice';
export type ModelCategoryData = {
category: FilterableModelType;
i18nKey: string;
filter: (config: AnyModelConfig) => boolean;
};
export const MODEL_CATEGORIES: Record<string, ModelCategoryData> = {
export const MODEL_CATEGORIES: Record<FilterableModelType, ModelCategoryData> = {
unknown: {
category: 'unknown',
i18nKey: 'common.unknown',
filter: isUnknownModelConfig,
},
main: {
i18nKey: 'model_manager.category.main_models',
category: 'main',
i18nKey: 'modelManager.main',
filter: isNonRefinerMainModelConfig,
},
refiner: {
i18nKey: 'model_manager.category.refiner_models',
category: 'refiner',
i18nKey: 'sdxl.refiner',
filter: isRefinerMainModelModelConfig,
},
lora: {
i18nKey: 'model_manager.category.lora_models',
category: 'lora',
i18nKey: 'modelManager.loraModels',
filter: isLoRAModelConfig,
},
embedding: {
i18nKey: 'model_manager.category.embedding_models',
category: 'embedding',
i18nKey: 'modelManager.textualInversions',
filter: isTIModelConfig,
},
controlnet: {
i18nKey: 'model_manager.category.controlnet_models',
category: 'controlnet',
i18nKey: 'ControlNet',
filter: isControlNetModelConfig,
},
t2i_adapter: {
i18nKey: 'model_manager.category.t2i_adapter_models',
category: 't2i_adapter',
i18nKey: 'common.t2iAdapter',
filter: isT2IAdapterModelConfig,
},
t5_encoder: {
i18nKey: 'model_manager.category.t5_encoder_models',
category: 't5_encoder',
i18nKey: 'modelManager.t5Encoder',
filter: isT5EncoderModelConfig,
},
control_lora: {
i18nKey: 'model_manager.category.control_lora_models',
category: 'control_lora',
i18nKey: 'modelManager.controlLora',
filter: isControlLoRAModelConfig,
},
clip_embed: {
i18nKey: 'model_manager.category.clip_embed_models',
category: 'clip_embed',
i18nKey: 'modelManager.clipEmbed',
filter: isCLIPEmbedModelConfig,
},
spandrel: {
i18nKey: 'model_manager.category.spandrel_image_to_image_models',
spandrel_image_to_image: {
category: 'spandrel_image_to_image',
i18nKey: 'modelManager.spandrelImageToImage',
filter: isSpandrelImageToImageModelConfig,
},
ip_adapter: {
i18nKey: 'model_manager.category.ip_adapter_models',
category: 'ip_adapter',
i18nKey: 'common.ipAdapter',
filter: isIPAdapterModelConfig,
},
vae: {
i18nKey: 'model_manager.category.vae_models',
category: 'vae',
i18nKey: 'VAE',
filter: isVAEModelConfig,
},
clip_vision: {
i18nKey: 'model_manager.category.clip_vision_models',
category: 'clip_vision',
i18nKey: 'CLIP Vision',
filter: isCLIPVisionModelConfig,
},
siglip: {
i18nKey: 'model_manager.category.siglip_models',
category: 'siglip',
i18nKey: 'modelManager.sigLip',
filter: isSigLipModelConfig,
},
flux_redux: {
i18nKey: 'model_manager.category.flux_redux_models',
category: 'flux_redux',
i18nKey: 'modelManager.fluxRedux',
filter: isFluxReduxModelConfig,
},
llava_one_vision: {
i18nKey: 'model_manager.category.llava_one_vision_models',
llava_onevision: {
category: 'llava_onevision',
i18nKey: 'modelManager.llavaOnevision',
filter: isLLaVAModelConfig,
},
unknown: {
i18nKey: 'model_manager.category.unknown_models',
filter: isUnknownModelConfig,
video: {
category: 'video',
i18nKey: 'Video',
filter: isVideoModelConfig,
},
};
export const MODEL_CATEGORIES_AS_LIST = objectEntries(MODEL_CATEGORIES).map(([category, { i18nKey, filter }]) => ({
category,
i18nKey,
filter,
}));
/**
* Mapping of model base to its color
*/

View File

@@ -2,7 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { MODEL_CATEGORIES } from 'features/modelManagerV2/models';
import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
import {
type FilterableModelType,
selectFilteredModelType,
@@ -31,7 +31,7 @@ const ModelList = () => {
const byCategory: { i18nKey: string; configs: AnyModelConfig[] }[] = [];
const total = baseFilteredModelConfigs.length;
let renderedTotal = 0;
for (const { i18nKey, filter } of Object.values(MODEL_CATEGORIES)) {
for (const { i18nKey, filter } of MODEL_CATEGORIES_AS_LIST) {
const configs = baseFilteredModelConfigs.filter(filter);
renderedTotal += configs.length;
byCategory.push({ i18nKey, configs });

View File

@@ -1,47 +1,17 @@
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import type { ModelCategoryData } from 'features/modelManagerV2/models';
import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFunnelBold } from 'react-icons/pi';
import { objectKeys } from 'tsafe';
export const ModelTypeFilter = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
() => ({
main: t('modelManager.main'),
refiner: t('sdxl.refiner'),
lora: 'LoRA',
embedding: t('modelManager.textualInversions'),
controlnet: 'ControlNet',
vae: 'VAE',
t2i_adapter: t('common.t2iAdapter'),
t5_encoder: t('modelManager.t5Encoder'),
clip_embed: t('modelManager.clipEmbed'),
ip_adapter: t('common.ipAdapter'),
clip_vision: 'CLIP Vision',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
control_lora: t('modelManager.controlLora'),
siglip: t('modelManager.sigLip'),
flux_redux: t('modelManager.fluxRedux'),
llava_onevision: t('modelManager.llavaOnevision'),
video: t('modelManager.video'),
unknown: t('modelManager.unknown'),
}),
[t]
);
const filteredModelType = useAppSelector(selectFilteredModelType);
const selectModelType = useCallback(
(option: FilterableModelType) => {
dispatch(setFilteredModelType(option));
},
[dispatch]
);
const clearModelType = useCallback(() => {
dispatch(setFilteredModelType(null));
}, [dispatch]);
@@ -49,18 +19,12 @@ export const ModelTypeFilter = memo(() => {
return (
<Menu>
<MenuButton as={Button} size="sm" leftIcon={<PiFunnelBold />}>
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
{filteredModelType ? t(MODEL_CATEGORIES[filteredModelType].i18nKey) : t('modelManager.allModels')}
</MenuButton>
<MenuList>
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
{objectKeys(MODEL_TYPE_LABELS).map((option) => (
<MenuItem
key={option}
bg={filteredModelType === option ? 'base.700' : 'transparent'}
onClick={selectModelType.bind(null, option)}
>
{MODEL_TYPE_LABELS[option]}
</MenuItem>
{MODEL_CATEGORIES_AS_LIST.map((data) => (
<ModelMenuItem key={data.category} data={data} />
))}
</MenuList>
</Menu>
@@ -68,3 +32,18 @@ export const ModelTypeFilter = memo(() => {
});
ModelTypeFilter.displayName = 'ModelTypeFilter';
const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const filteredModelType = useAppSelector(selectFilteredModelType);
const onClick = useCallback(() => {
dispatch(setFilteredModelType(data.category));
}, [data.category, dispatch]);
return (
<MenuItem bg={filteredModelType === data.category ? 'base.700' : 'transparent'} onClick={onClick}>
{t(data.i18nKey)}
</MenuItem>
);
});
ModelMenuItem.displayName = 'ModelMenuItem';