diff --git a/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx b/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx index fc08047e27..709d0c94df 100644 --- a/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx @@ -28,10 +28,13 @@ import { setActiveTab } from 'features/ui/store/uiSlice'; import { filesize } from 'filesize'; import { memo, useCallback, useMemo, useRef } from 'react'; import { Trans, useTranslation } from 'react-i18next'; -import { PiCaretDownBold } from 'react-icons/pi'; +import { PiCaretDownBold, PiStarFill } from 'react-icons/pi'; import type { AnyModelConfig, BaseModelType } from 'services/api/types'; -const getOptionId = (modelConfig: AnyModelConfig) => modelConfig.key; +// Type for models with starred field +type StarredModelConfig = AnyModelConfig & { starred?: boolean }; + +const getOptionId = (modelConfig: T & { starred?: boolean }) => modelConfig.key; const ModelManagerLink = memo((props: ButtonProps) => { const onClickGoToModelManager = useStore($onClickGoToModelManager); @@ -143,41 +146,30 @@ export const ModelPicker = typedMemo( initialGroupStates?: Record; }) => { const { t } = useTranslation(); - const options = useMemo[]>(() => { + const options = useMemo<(T & { starred?: boolean })[] | Group[]>(() => { if (!grouped) { - // Handle related models for non-grouped view - if (relatedModelKeys.length > 0) { - const relatedModels: T[] = []; - const otherModels: T[] = []; - - for (const modelConfig of modelConfigs) { - if (relatedModelKeys.includes(modelConfig.key)) { - relatedModels.push(modelConfig); - } else { - otherModels.push(modelConfig); - } - } - - return [...relatedModels, ...otherModels]; - } - return modelConfigs; + // Add starred field to model options and sort them + const modelsWithStarred = modelConfigs.map(model => ({ + ...model, + starred: relatedModelKeys.includes(model.key) + })); + + // Sort so starred models come first + return modelsWithStarred.sort((a, b) => { + if (a.starred && !b.starred) return -1; + if (!a.starred && b.starred) return 1; + return 0; + }); } // When all groups are disabled, we show all models - const groups: Record> = {}; - const relatedModels: T[] = []; + const groups: Record> = {}; for (const modelConfig of modelConfigs) { - // Check if this model is related and separate it - if (relatedModelKeys.length > 0 && relatedModelKeys.includes(modelConfig.key)) { - relatedModels.push(modelConfig); - continue; - } - const groupId = getGroupIDFromModelConfig(modelConfig); let group = groups[groupId]; if (!group) { - group = buildGroup({ + group = buildGroup({ id: modelConfig.base, color: `${getGroupColorSchemeFromModelConfig(modelConfig)}.300`, shortName: getGroupShortNameFromModelConfig(modelConfig), @@ -188,29 +180,27 @@ export const ModelPicker = typedMemo( groups[groupId] = group; } if (group) { - group.options.push(modelConfig); + // Add starred field to the model + const modelWithStarred = { + ...modelConfig, + starred: relatedModelKeys.includes(modelConfig.key) + }; + group.options.push(modelWithStarred); } } - const _options: Group[] = []; + const _options: Group[] = []; - // Add related models group first if there are any - if (relatedModels.length > 0) { - const relatedGroup = buildGroup({ - id: 'related', - color: 'accent.300', - shortName: t('modelManager.showOnlyRelatedModels'), - name: t('modelManager.relatedModels'), - getOptionCountString: (count) => t('common.model_withCount', { count }), - options: relatedModels, - }); - _options.push(relatedGroup); - } - - // Add other groups in the original order + // Add groups in the original order for (const groupId of ['api', 'flux', 'cogview4', 'sdxl', 'sd-3', 'sd-2', 'sd-1']) { const group = groups[groupId]; if (group) { + // Sort options within each group so starred ones come first + group.options.sort((a, b) => { + if (a.starred && !b.starred) return -1; + if (!a.starred && b.starred) return 1; + return 0; + }); _options.push(group); delete groups[groupId]; } @@ -220,7 +210,7 @@ export const ModelPicker = typedMemo( return _options; }, [grouped, modelConfigs, relatedModelKeys, t]); const popover = useDisclosure(false); - const pickerRef = useRef>(null); + const pickerRef = useRef>(null); const onClose = useCallback(() => { popover.close(); @@ -228,9 +218,11 @@ export const ModelPicker = typedMemo( }, [popover]); const onSelect = useCallback( - (model: T) => { + (model: T & { starred?: boolean }) => { onClose(); - onChange(model); + // Remove the starred field before passing to onChange + const { starred, ...modelWithoutStarred } = model; + onChange(modelWithoutStarred as T); }, [onChange, onClose] ); @@ -245,14 +237,6 @@ export const ModelPicker = typedMemo( return undefined; }, [allowEmpty, isInvalid, selectedModelConfig]); - // Create a component wrapper that includes related model styling - const RelatedModelPickerOptionComponent = useCallback( - ({ option, ...rest }: { option: T } & BoxProps) => ( - - ), - [relatedModelKeys] - ); - return ( - + handleRef={pickerRef} optionsOrGroups={options} - getOptionId={getOptionId} + getOptionId={getOptionId} onSelect={onSelect} - selectedOption={selectedModelConfig} - isMatch={isMatch} - OptionComponent={RelatedModelPickerOptionComponent} + selectedOption={selectedModelConfig ? { ...selectedModelConfig, starred: relatedModelKeys.includes(selectedModelConfig.key) } : undefined} + isMatch={isMatch} + OptionComponent={PickerOptionComponent} noOptionsFallback={} noMatchesFallback={t('modelManager.noMatchingModels')} NextToSearchBar={} @@ -338,19 +322,20 @@ const optionNameSx: SystemStyleObject = { }; const PickerOptionComponent = typedMemo( - ({ option, isRelated = false, ...rest }: { option: AnyModelConfig; isRelated?: boolean } & BoxProps) => { - const { $compactView } = usePickerContext(); + ({ option, ...rest }: { option: T & { starred?: boolean } } & BoxProps) => { + const { $compactView } = usePickerContext(); const compactView = useStore($compactView); - const displayName = isRelated ? `* ${option.name}` : option.name; - return ( {!compactView && option.cover_image && } + {option.starred && ( + + )} - {displayName} + {option.name} {option.file_size > 0 && ( @@ -378,7 +363,7 @@ const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = { 'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'], }; -const isMatch = (model: AnyModelConfig, searchTerm: string) => { +const isMatch = (model: T & { starred?: boolean }, searchTerm: string) => { const regex = getRegex(searchTerm); const bases = BASE_KEYWORDS[model.base] ?? [model.base]; const testString =