mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 15:04:59 -05:00
simplifies Modelpicker wrapper
This commit is contained in:
committed by
psychedelicious
parent
29e87fc615
commit
f35b05be43
@@ -28,10 +28,13 @@ import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { filesize } from 'filesize';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
import { PiCaretDownBold, PiStarFill } from 'react-icons/pi';
|
||||
import type { AnyModelConfig, BaseModelType } from 'services/api/types';
|
||||
|
||||
const getOptionId = (modelConfig: AnyModelConfig) => modelConfig.key;
|
||||
// Type for models with starred field
|
||||
type StarredModelConfig = AnyModelConfig & { starred?: boolean };
|
||||
|
||||
const getOptionId = <T extends AnyModelConfig>(modelConfig: T & { starred?: boolean }) => modelConfig.key;
|
||||
|
||||
const ModelManagerLink = memo((props: ButtonProps) => {
|
||||
const onClickGoToModelManager = useStore($onClickGoToModelManager);
|
||||
@@ -143,41 +146,30 @@ export const ModelPicker = typedMemo(
|
||||
initialGroupStates?: Record<string, boolean>;
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const options = useMemo<T[] | Group<T>[]>(() => {
|
||||
const options = useMemo<(T & { starred?: boolean })[] | Group<T & { starred?: boolean }>[]>(() => {
|
||||
if (!grouped) {
|
||||
// Handle related models for non-grouped view
|
||||
if (relatedModelKeys.length > 0) {
|
||||
const relatedModels: T[] = [];
|
||||
const otherModels: T[] = [];
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
if (relatedModelKeys.includes(modelConfig.key)) {
|
||||
relatedModels.push(modelConfig);
|
||||
} else {
|
||||
otherModels.push(modelConfig);
|
||||
}
|
||||
}
|
||||
|
||||
return [...relatedModels, ...otherModels];
|
||||
}
|
||||
return modelConfigs;
|
||||
// Add starred field to model options and sort them
|
||||
const modelsWithStarred = modelConfigs.map(model => ({
|
||||
...model,
|
||||
starred: relatedModelKeys.includes(model.key)
|
||||
}));
|
||||
|
||||
// Sort so starred models come first
|
||||
return modelsWithStarred.sort((a, b) => {
|
||||
if (a.starred && !b.starred) return -1;
|
||||
if (!a.starred && b.starred) return 1;
|
||||
return 0;
|
||||
});
|
||||
}
|
||||
|
||||
// When all groups are disabled, we show all models
|
||||
const groups: Record<string, Group<T>> = {};
|
||||
const relatedModels: T[] = [];
|
||||
const groups: Record<string, Group<T & { starred?: boolean }>> = {};
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
// Check if this model is related and separate it
|
||||
if (relatedModelKeys.length > 0 && relatedModelKeys.includes(modelConfig.key)) {
|
||||
relatedModels.push(modelConfig);
|
||||
continue;
|
||||
}
|
||||
|
||||
const groupId = getGroupIDFromModelConfig(modelConfig);
|
||||
let group = groups[groupId];
|
||||
if (!group) {
|
||||
group = buildGroup<T>({
|
||||
group = buildGroup<T & { starred?: boolean }>({
|
||||
id: modelConfig.base,
|
||||
color: `${getGroupColorSchemeFromModelConfig(modelConfig)}.300`,
|
||||
shortName: getGroupShortNameFromModelConfig(modelConfig),
|
||||
@@ -188,29 +180,27 @@ export const ModelPicker = typedMemo(
|
||||
groups[groupId] = group;
|
||||
}
|
||||
if (group) {
|
||||
group.options.push(modelConfig);
|
||||
// Add starred field to the model
|
||||
const modelWithStarred = {
|
||||
...modelConfig,
|
||||
starred: relatedModelKeys.includes(modelConfig.key)
|
||||
};
|
||||
group.options.push(modelWithStarred);
|
||||
}
|
||||
}
|
||||
|
||||
const _options: Group<T>[] = [];
|
||||
const _options: Group<T & { starred?: boolean }>[] = [];
|
||||
|
||||
// Add related models group first if there are any
|
||||
if (relatedModels.length > 0) {
|
||||
const relatedGroup = buildGroup<T>({
|
||||
id: 'related',
|
||||
color: 'accent.300',
|
||||
shortName: t('modelManager.showOnlyRelatedModels'),
|
||||
name: t('modelManager.relatedModels'),
|
||||
getOptionCountString: (count) => t('common.model_withCount', { count }),
|
||||
options: relatedModels,
|
||||
});
|
||||
_options.push(relatedGroup);
|
||||
}
|
||||
|
||||
// Add other groups in the original order
|
||||
// Add groups in the original order
|
||||
for (const groupId of ['api', 'flux', 'cogview4', 'sdxl', 'sd-3', 'sd-2', 'sd-1']) {
|
||||
const group = groups[groupId];
|
||||
if (group) {
|
||||
// Sort options within each group so starred ones come first
|
||||
group.options.sort((a, b) => {
|
||||
if (a.starred && !b.starred) return -1;
|
||||
if (!a.starred && b.starred) return 1;
|
||||
return 0;
|
||||
});
|
||||
_options.push(group);
|
||||
delete groups[groupId];
|
||||
}
|
||||
@@ -220,7 +210,7 @@ export const ModelPicker = typedMemo(
|
||||
return _options;
|
||||
}, [grouped, modelConfigs, relatedModelKeys, t]);
|
||||
const popover = useDisclosure(false);
|
||||
const pickerRef = useRef<PickerContextState<T>>(null);
|
||||
const pickerRef = useRef<PickerContextState<T & { starred?: boolean }>>(null);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
popover.close();
|
||||
@@ -228,9 +218,11 @@ export const ModelPicker = typedMemo(
|
||||
}, [popover]);
|
||||
|
||||
const onSelect = useCallback(
|
||||
(model: T) => {
|
||||
(model: T & { starred?: boolean }) => {
|
||||
onClose();
|
||||
onChange(model);
|
||||
// Remove the starred field before passing to onChange
|
||||
const { starred, ...modelWithoutStarred } = model;
|
||||
onChange(modelWithoutStarred as T);
|
||||
},
|
||||
[onChange, onClose]
|
||||
);
|
||||
@@ -245,14 +237,6 @@ export const ModelPicker = typedMemo(
|
||||
return undefined;
|
||||
}, [allowEmpty, isInvalid, selectedModelConfig]);
|
||||
|
||||
// Create a component wrapper that includes related model styling
|
||||
const RelatedModelPickerOptionComponent = useCallback(
|
||||
({ option, ...rest }: { option: T } & BoxProps) => (
|
||||
<PickerOptionComponent option={option} isRelated={relatedModelKeys.includes(option.key)} {...rest} />
|
||||
),
|
||||
[relatedModelKeys]
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
isOpen={popover.isOpen}
|
||||
@@ -279,14 +263,14 @@ export const ModelPicker = typedMemo(
|
||||
<PopoverContent p={0} w={400} h={400}>
|
||||
<PopoverArrow />
|
||||
<PopoverBody p={0} w="full" h="full">
|
||||
<Picker<T>
|
||||
<Picker<T & { starred?: boolean }>
|
||||
handleRef={pickerRef}
|
||||
optionsOrGroups={options}
|
||||
getOptionId={getOptionId}
|
||||
getOptionId={getOptionId<T>}
|
||||
onSelect={onSelect}
|
||||
selectedOption={selectedModelConfig}
|
||||
isMatch={isMatch}
|
||||
OptionComponent={RelatedModelPickerOptionComponent}
|
||||
selectedOption={selectedModelConfig ? { ...selectedModelConfig, starred: relatedModelKeys.includes(selectedModelConfig.key) } : undefined}
|
||||
isMatch={isMatch<T>}
|
||||
OptionComponent={PickerOptionComponent<T>}
|
||||
noOptionsFallback={<NoOptionsFallback noOptionsText={noOptionsText} />}
|
||||
noMatchesFallback={t('modelManager.noMatchingModels')}
|
||||
NextToSearchBar={<NavigateToModelManagerButton />}
|
||||
@@ -338,19 +322,20 @@ const optionNameSx: SystemStyleObject = {
|
||||
};
|
||||
|
||||
const PickerOptionComponent = typedMemo(
|
||||
({ option, isRelated = false, ...rest }: { option: AnyModelConfig; isRelated?: boolean } & BoxProps) => {
|
||||
const { $compactView } = usePickerContext<AnyModelConfig>();
|
||||
<T extends AnyModelConfig>({ option, ...rest }: { option: T & { starred?: boolean } } & BoxProps) => {
|
||||
const { $compactView } = usePickerContext<T & { starred?: boolean }>();
|
||||
const compactView = useStore($compactView);
|
||||
|
||||
const displayName = isRelated ? `* ${option.name}` : option.name;
|
||||
|
||||
return (
|
||||
<Flex {...rest} sx={optionSx} data-is-compact={compactView}>
|
||||
{!compactView && option.cover_image && <ModelImage image_url={option.cover_image} />}
|
||||
<Flex flexDir="column" gap={1} flex={1}>
|
||||
<Flex gap={2} alignItems="center">
|
||||
{option.starred && (
|
||||
<PiStarFill color="yellow" size={16} />
|
||||
)}
|
||||
<Text sx={optionNameSx} data-is-compact={compactView}>
|
||||
{displayName}
|
||||
{option.name}
|
||||
</Text>
|
||||
<Spacer />
|
||||
{option.file_size > 0 && (
|
||||
@@ -378,7 +363,7 @@ const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
|
||||
'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'],
|
||||
};
|
||||
|
||||
const isMatch = (model: AnyModelConfig, searchTerm: string) => {
|
||||
const isMatch = <T extends AnyModelConfig>(model: T & { starred?: boolean }, searchTerm: string) => {
|
||||
const regex = getRegex(searchTerm);
|
||||
const bases = BASE_KEYWORDS[model.base] ?? [model.base];
|
||||
const testString =
|
||||
|
||||
Reference in New Issue
Block a user