From 5fa6c0b4136884acaf4c1feb46fd599d85c9f7ad Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 27 Jun 2025 01:16:52 +0000 Subject: [PATCH] Enhance model picker with related models and improved filtering Co-authored-by: kent --- .../features/lora/components/LoRASelect.tsx | 56 +++++++++++++---- .../parameters/components/ModelPicker.tsx | 60 +++++++++++++++++-- 2 files changed, 101 insertions(+), 15 deletions(-) diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index 293acde364..82e64a97a3 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -1,26 +1,65 @@ import { FormControl, FormLabel } from '@invoke-ai/ui-library'; import { createSelector } from '@reduxjs/toolkit'; +import { EMPTY_ARRAY } from 'app/store/constants'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import type { GroupStatusMap } from 'common/components/Picker/Picker'; -import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedModelCombobox'; +import { uniq } from 'es-toolkit/compat'; import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; -import { selectBase } from 'features/controlLayers/store/paramsSlice'; +import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { ModelPicker } from 'features/parameters/components/ModelPicker'; import { API_BASE_MODELS } from 'features/parameters/types/constants'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships'; import { useLoRAModels } from 'services/api/hooks/modelsByType'; import type { LoRAModelConfig } from 'services/api/types'; const selectLoRAs = createSelector(selectLoRAsSlice, (loras) => loras.loras); +const selectSelectedModelKeys = createSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => { + const keys: string[] = []; + const main = params.model; + const vae = params.vae; + const refiner = params.refinerModel; + const controlnet = params.controlLora; + + if (main) { + keys.push(main.key); + } + if (vae) { + keys.push(vae.key); + } + if (refiner) { + keys.push(refiner.key); + } + if (controlnet) { + keys.push(controlnet.key); + } + for (const { model } of loras.loras) { + keys.push(model.key); + } + + return uniq(keys); +}); + const LoRASelect = () => { const dispatch = useAppDispatch(); const [modelConfigs, { isLoading }] = useLoRAModels(); const { t } = useTranslation(); const addedLoRAs = useAppSelector(selectLoRAs); - const currentBaseModel = useAppSelector(selectBase); + const selectedKeys = useAppSelector(selectSelectedModelKeys); + + const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, { + selectFromResult: ({ data }) => { + if (!data) { + return { relatedKeys: EMPTY_ARRAY }; + } + return { relatedKeys: data }; + }, + }); + + const currentBaseModel = useAppSelector((state) => state.controlLayers.present.params.model?.base); const getIsDisabled = useCallback( (model: LoRAModelConfig): boolean => { @@ -42,23 +81,17 @@ const LoRASelect = () => { [dispatch] ); - const { options } = useRelatedGroupedModelCombobox({ - modelConfigs, - getIsDisabled, - onChange, - }); - const placeholder = useMemo(() => { if (isLoading) { return t('common.loading'); } - if (options.length === 0) { + if (modelConfigs.length === 0) { return t('models.noLoRAsInstalled'); } return t('models.addLora'); - }, [isLoading, options.length, t]); + }, [isLoading, modelConfigs.length, t]); // Calculate initial group states to default to the current base model architecture const initialGroupStates = useMemo(() => { @@ -82,6 +115,7 @@ const LoRASelect = () => { modelConfigs={modelConfigs} onChange={onChange} grouped + relatedModelKeys={relatedKeys} selectedModelConfig={undefined} allowEmpty placeholder={placeholder} diff --git a/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx b/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx index c642491e4e..4b248e40db 100644 --- a/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx @@ -118,6 +118,7 @@ export const ModelPicker = typedMemo( selectedModelConfig, onChange, grouped, + relatedModelKeys = [], getIsOptionDisabled, placeholder, allowEmpty, @@ -131,6 +132,7 @@ export const ModelPicker = typedMemo( selectedModelConfig: T | undefined; onChange: (modelConfig: T) => void; grouped?: boolean; + relatedModelKeys?: string[]; getIsOptionDisabled?: (model: T) => boolean; placeholder?: string; allowEmpty?: boolean; @@ -143,13 +145,35 @@ export const ModelPicker = typedMemo( const { t } = useTranslation(); const options = useMemo[]>(() => { if (!grouped) { + // Handle related models for non-grouped view + if (relatedModelKeys.length > 0) { + const relatedModels: T[] = []; + const otherModels: T[] = []; + + for (const modelConfig of modelConfigs) { + if (relatedModelKeys.includes(modelConfig.key)) { + relatedModels.push(modelConfig); + } else { + otherModels.push(modelConfig); + } + } + + return [...relatedModels, ...otherModels]; + } return modelConfigs; } // When all groups are disabled, we show all models const groups: Record> = {}; + const relatedModels: T[] = []; for (const modelConfig of modelConfigs) { + // Check if this model is related and separate it + if (relatedModelKeys.length > 0 && relatedModelKeys.includes(modelConfig.key)) { + relatedModels.push(modelConfig); + continue; + } + const groupId = getGroupIDFromModelConfig(modelConfig); let group = groups[groupId]; if (!group) { @@ -170,6 +194,20 @@ export const ModelPicker = typedMemo( const _options: Group[] = []; + // Add related models group first if there are any + if (relatedModels.length > 0) { + const relatedGroup = buildGroup({ + id: 'related', + color: 'accent.300', + shortName: t('modelManager.showOnlyRelatedModels'), + name: t('modelManager.relatedModels'), + getOptionCountString: (count) => t('common.model_withCount', { count }), + options: relatedModels, + }); + _options.push(relatedGroup); + } + + // Add other groups in the original order for (const groupId of ['api', 'flux', 'cogview4', 'sdxl', 'sd-3', 'sd-2', 'sd-1']) { const group = groups[groupId]; if (group) { @@ -180,7 +218,7 @@ export const ModelPicker = typedMemo( _options.push(...Object.values(groups)); return _options; - }, [grouped, modelConfigs, t]); + }, [grouped, modelConfigs, relatedModelKeys, t]); const popover = useDisclosure(false); const pickerRef = useRef>(null); @@ -207,6 +245,18 @@ export const ModelPicker = typedMemo( return undefined; }, [allowEmpty, isInvalid, selectedModelConfig]); + // Create a component wrapper that includes related model styling + const RelatedModelPickerOptionComponent = useCallback( + ({ option, ...rest }: { option: T } & BoxProps) => ( + + ), + [relatedModelKeys] + ); + return ( } noMatchesFallback={t('modelManager.noMatchingModels')} NextToSearchBar={} @@ -291,17 +341,19 @@ const optionNameSx: SystemStyleObject = { }, }; -const PickerOptionComponent = typedMemo(({ option, ...rest }: { option: AnyModelConfig } & BoxProps) => { +const PickerOptionComponent = typedMemo(({ option, isRelated = false, ...rest }: { option: AnyModelConfig; isRelated?: boolean } & BoxProps) => { const { $compactView } = usePickerContext(); const compactView = useStore($compactView); + const displayName = isRelated ? `* ${option.name}` : option.name; + return ( {!compactView && option.cover_image && } - {option.name} + {displayName} {option.file_size > 0 && (