mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor(ui)refactor(ui): more cleanup of model categories
This commit is contained in:
@@ -18,84 +18,117 @@ import {
|
||||
isTIModelConfig,
|
||||
isUnknownModelConfig,
|
||||
isVAEModelConfig,
|
||||
isVideoModelConfig,
|
||||
} from 'services/api/types';
|
||||
import { objectEntries } from 'tsafe';
|
||||
|
||||
type ModelCategoryData = {
|
||||
import type { FilterableModelType } from './store/modelManagerV2Slice';
|
||||
|
||||
export type ModelCategoryData = {
|
||||
category: FilterableModelType;
|
||||
i18nKey: string;
|
||||
filter: (config: AnyModelConfig) => boolean;
|
||||
};
|
||||
|
||||
export const MODEL_CATEGORIES: Record<string, ModelCategoryData> = {
|
||||
export const MODEL_CATEGORIES: Record<FilterableModelType, ModelCategoryData> = {
|
||||
unknown: {
|
||||
category: 'unknown',
|
||||
i18nKey: 'common.unknown',
|
||||
filter: isUnknownModelConfig,
|
||||
},
|
||||
main: {
|
||||
i18nKey: 'model_manager.category.main_models',
|
||||
category: 'main',
|
||||
i18nKey: 'modelManager.main',
|
||||
filter: isNonRefinerMainModelConfig,
|
||||
},
|
||||
refiner: {
|
||||
i18nKey: 'model_manager.category.refiner_models',
|
||||
category: 'refiner',
|
||||
i18nKey: 'sdxl.refiner',
|
||||
filter: isRefinerMainModelModelConfig,
|
||||
},
|
||||
lora: {
|
||||
i18nKey: 'model_manager.category.lora_models',
|
||||
category: 'lora',
|
||||
i18nKey: 'modelManager.loraModels',
|
||||
filter: isLoRAModelConfig,
|
||||
},
|
||||
embedding: {
|
||||
i18nKey: 'model_manager.category.embedding_models',
|
||||
category: 'embedding',
|
||||
i18nKey: 'modelManager.textualInversions',
|
||||
filter: isTIModelConfig,
|
||||
},
|
||||
controlnet: {
|
||||
i18nKey: 'model_manager.category.controlnet_models',
|
||||
category: 'controlnet',
|
||||
i18nKey: 'ControlNet',
|
||||
filter: isControlNetModelConfig,
|
||||
},
|
||||
t2i_adapter: {
|
||||
i18nKey: 'model_manager.category.t2i_adapter_models',
|
||||
category: 't2i_adapter',
|
||||
i18nKey: 'common.t2iAdapter',
|
||||
filter: isT2IAdapterModelConfig,
|
||||
},
|
||||
t5_encoder: {
|
||||
i18nKey: 'model_manager.category.t5_encoder_models',
|
||||
category: 't5_encoder',
|
||||
i18nKey: 'modelManager.t5Encoder',
|
||||
filter: isT5EncoderModelConfig,
|
||||
},
|
||||
control_lora: {
|
||||
i18nKey: 'model_manager.category.control_lora_models',
|
||||
category: 'control_lora',
|
||||
i18nKey: 'modelManager.controlLora',
|
||||
filter: isControlLoRAModelConfig,
|
||||
},
|
||||
clip_embed: {
|
||||
i18nKey: 'model_manager.category.clip_embed_models',
|
||||
category: 'clip_embed',
|
||||
i18nKey: 'modelManager.clipEmbed',
|
||||
filter: isCLIPEmbedModelConfig,
|
||||
},
|
||||
spandrel: {
|
||||
i18nKey: 'model_manager.category.spandrel_image_to_image_models',
|
||||
spandrel_image_to_image: {
|
||||
category: 'spandrel_image_to_image',
|
||||
i18nKey: 'modelManager.spandrelImageToImage',
|
||||
filter: isSpandrelImageToImageModelConfig,
|
||||
},
|
||||
ip_adapter: {
|
||||
i18nKey: 'model_manager.category.ip_adapter_models',
|
||||
category: 'ip_adapter',
|
||||
i18nKey: 'common.ipAdapter',
|
||||
filter: isIPAdapterModelConfig,
|
||||
},
|
||||
vae: {
|
||||
i18nKey: 'model_manager.category.vae_models',
|
||||
category: 'vae',
|
||||
i18nKey: 'VAE',
|
||||
filter: isVAEModelConfig,
|
||||
},
|
||||
clip_vision: {
|
||||
i18nKey: 'model_manager.category.clip_vision_models',
|
||||
category: 'clip_vision',
|
||||
i18nKey: 'CLIP Vision',
|
||||
filter: isCLIPVisionModelConfig,
|
||||
},
|
||||
siglip: {
|
||||
i18nKey: 'model_manager.category.siglip_models',
|
||||
category: 'siglip',
|
||||
i18nKey: 'modelManager.sigLip',
|
||||
filter: isSigLipModelConfig,
|
||||
},
|
||||
flux_redux: {
|
||||
i18nKey: 'model_manager.category.flux_redux_models',
|
||||
category: 'flux_redux',
|
||||
i18nKey: 'modelManager.fluxRedux',
|
||||
filter: isFluxReduxModelConfig,
|
||||
},
|
||||
llava_one_vision: {
|
||||
i18nKey: 'model_manager.category.llava_one_vision_models',
|
||||
llava_onevision: {
|
||||
category: 'llava_onevision',
|
||||
i18nKey: 'modelManager.llavaOnevision',
|
||||
filter: isLLaVAModelConfig,
|
||||
},
|
||||
unknown: {
|
||||
i18nKey: 'model_manager.category.unknown_models',
|
||||
filter: isUnknownModelConfig,
|
||||
video: {
|
||||
category: 'video',
|
||||
i18nKey: 'Video',
|
||||
filter: isVideoModelConfig,
|
||||
},
|
||||
};
|
||||
|
||||
export const MODEL_CATEGORIES_AS_LIST = objectEntries(MODEL_CATEGORIES).map(([category, { i18nKey, filter }]) => ({
|
||||
category,
|
||||
i18nKey,
|
||||
filter,
|
||||
}));
|
||||
|
||||
/**
|
||||
* Mapping of model base to its color
|
||||
*/
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { MODEL_CATEGORIES } from 'features/modelManagerV2/models';
|
||||
import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
|
||||
import {
|
||||
type FilterableModelType,
|
||||
selectFilteredModelType,
|
||||
@@ -31,7 +31,7 @@ const ModelList = () => {
|
||||
const byCategory: { i18nKey: string; configs: AnyModelConfig[] }[] = [];
|
||||
const total = baseFilteredModelConfigs.length;
|
||||
let renderedTotal = 0;
|
||||
for (const { i18nKey, filter } of Object.values(MODEL_CATEGORIES)) {
|
||||
for (const { i18nKey, filter } of MODEL_CATEGORIES_AS_LIST) {
|
||||
const configs = baseFilteredModelConfigs.filter(filter);
|
||||
renderedTotal += configs.length;
|
||||
byCategory.push({ i18nKey, configs });
|
||||
|
||||
@@ -1,47 +1,17 @@
|
||||
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import type { ModelCategoryData } from 'features/modelManagerV2/models';
|
||||
import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
|
||||
import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFunnelBold } from 'react-icons/pi';
|
||||
import { objectKeys } from 'tsafe';
|
||||
|
||||
export const ModelTypeFilter = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
||||
() => ({
|
||||
main: t('modelManager.main'),
|
||||
refiner: t('sdxl.refiner'),
|
||||
lora: 'LoRA',
|
||||
embedding: t('modelManager.textualInversions'),
|
||||
controlnet: 'ControlNet',
|
||||
vae: 'VAE',
|
||||
t2i_adapter: t('common.t2iAdapter'),
|
||||
t5_encoder: t('modelManager.t5Encoder'),
|
||||
clip_embed: t('modelManager.clipEmbed'),
|
||||
ip_adapter: t('common.ipAdapter'),
|
||||
clip_vision: 'CLIP Vision',
|
||||
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
|
||||
control_lora: t('modelManager.controlLora'),
|
||||
siglip: t('modelManager.sigLip'),
|
||||
flux_redux: t('modelManager.fluxRedux'),
|
||||
llava_onevision: t('modelManager.llavaOnevision'),
|
||||
video: t('modelManager.video'),
|
||||
unknown: t('modelManager.unknown'),
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
const filteredModelType = useAppSelector(selectFilteredModelType);
|
||||
|
||||
const selectModelType = useCallback(
|
||||
(option: FilterableModelType) => {
|
||||
dispatch(setFilteredModelType(option));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const clearModelType = useCallback(() => {
|
||||
dispatch(setFilteredModelType(null));
|
||||
}, [dispatch]);
|
||||
@@ -49,18 +19,12 @@ export const ModelTypeFilter = memo(() => {
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton as={Button} size="sm" leftIcon={<PiFunnelBold />}>
|
||||
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
|
||||
{filteredModelType ? t(MODEL_CATEGORIES[filteredModelType].i18nKey) : t('modelManager.allModels')}
|
||||
</MenuButton>
|
||||
<MenuList>
|
||||
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
|
||||
{objectKeys(MODEL_TYPE_LABELS).map((option) => (
|
||||
<MenuItem
|
||||
key={option}
|
||||
bg={filteredModelType === option ? 'base.700' : 'transparent'}
|
||||
onClick={selectModelType.bind(null, option)}
|
||||
>
|
||||
{MODEL_TYPE_LABELS[option]}
|
||||
</MenuItem>
|
||||
{MODEL_CATEGORIES_AS_LIST.map((data) => (
|
||||
<ModelMenuItem key={data.category} data={data} />
|
||||
))}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
@@ -68,3 +32,18 @@ export const ModelTypeFilter = memo(() => {
|
||||
});
|
||||
|
||||
ModelTypeFilter.displayName = 'ModelTypeFilter';
|
||||
|
||||
const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const filteredModelType = useAppSelector(selectFilteredModelType);
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(setFilteredModelType(data.category));
|
||||
}, [data.category, dispatch]);
|
||||
return (
|
||||
<MenuItem bg={filteredModelType === data.category ? 'base.700' : 'transparent'} onClick={onClick}>
|
||||
{t(data.i18nKey)}
|
||||
</MenuItem>
|
||||
);
|
||||
});
|
||||
ModelMenuItem.displayName = 'ModelMenuItem';
|
||||
|
||||
Reference in New Issue
Block a user