From 571d286506938abf315d39d2cc8d391eeacadcb1 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Thu, 26 Jun 2025 22:19:40 +0000 Subject: [PATCH] Enhance LoRA picker to default to current base model architecture Co-authored-by: kent Enhance LoRA picker to default filter by current base model architecture Co-authored-by: kent --- .../src/common/components/Picker/Picker.tsx | 23 +++++++++++++++---- .../features/lora/components/LoRASelect.tsx | 15 ++++++++++++ .../parameters/components/ModelPicker.tsx | 3 +++ 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/src/common/components/Picker/Picker.tsx b/invokeai/frontend/web/src/common/components/Picker/Picker.tsx index 14b2476a45..e3c0d35e10 100644 --- a/invokeai/frontend/web/src/common/components/Picker/Picker.tsx +++ b/invokeai/frontend/web/src/common/components/Picker/Picker.tsx @@ -198,6 +198,10 @@ type PickerProps = { * Whether the picker should be searchable. If true, renders a search input. */ searchable?: boolean; + /** + * Initial state for group toggles. If provided, groups will start with these states instead of all being disabled. + */ + initialGroupStates?: Record; }; export type PickerContextState = { @@ -312,7 +316,10 @@ const flattenOptions = (options: OptionOrGroup[]): T[] => { type GroupStatusMap = Record; -const useTogglableGroups = (options: OptionOrGroup[]) => { +const useTogglableGroups = ( + options: OptionOrGroup[], + initialGroupStates?: Record +) => { const groupsWithOptions = useMemo(() => { const ids: string[] = []; for (const optionOrGroup of options) { @@ -332,14 +339,16 @@ const useTogglableGroups = (options: OptionOrGroup[]) => { const groupStatusMap = $groupStatusMap.get(); const newMap: GroupStatusMap = {}; for (const id of groupsWithOptions) { - if (newMap[id] === undefined) { - newMap[id] = false; + if (initialGroupStates && initialGroupStates[id] !== undefined) { + newMap[id] = initialGroupStates[id]; } else if (groupStatusMap[id] !== undefined) { newMap[id] = groupStatusMap[id]; + } else { + newMap[id] = false; } } $groupStatusMap.set(newMap); - }, [groupsWithOptions, $groupStatusMap]); + }, [groupsWithOptions, $groupStatusMap, initialGroupStates]); const toggleGroup = useCallback( (idToToggle: string) => { @@ -511,10 +520,14 @@ export const Picker = typedMemo((props: PickerProps) => { OptionComponent = DefaultOptionComponent, NextToSearchBar, searchable, + initialGroupStates, } = props; const rootRef = useRef(null); const inputRef = useRef(null); - const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(optionsOrGroups); + const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups( + optionsOrGroups, + initialGroupStates + ); const $activeOptionId = useAtom(getFirstOptionId(optionsOrGroups, getOptionId)); const $compactView = useAtom(true); const $optionsOrGroups = useAtom(optionsOrGroups); diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index f49627b34f..4ab04783df 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -6,6 +6,7 @@ import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedMo import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; import { selectBase } from 'features/controlLayers/store/paramsSlice'; import { ModelPicker } from 'features/parameters/components/ModelPicker'; +import { API_BASE_MODELS } from 'features/parameters/types/constants'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useLoRAModels } from 'services/api/hooks/modelsByType'; @@ -58,6 +59,19 @@ const LoRASelect = () => { return t('models.addLora'); }, [isLoading, options.length, t]); + // Calculate initial group states to default to the current base model architecture + const initialGroupStates = useMemo(() => { + if (!currentBaseModel) { + return undefined; + } + + // Determine the group ID for the current base model + const groupId = API_BASE_MODELS.includes(currentBaseModel) ? 'api' : currentBaseModel; + + // Return a map with only the current base model group enabled + return { [groupId]: true }; + }, [currentBaseModel]); + return ( @@ -72,6 +86,7 @@ const LoRASelect = () => { placeholder={placeholder} getIsOptionDisabled={getIsDisabled} noOptionsText={t('models.noLoRAsInstalled')} + initialGroupStates={initialGroupStates} /> ); diff --git a/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx b/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx index 501e40b3a4..cd8c18c5c5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx @@ -125,6 +125,7 @@ export const ModelPicker = typedMemo( isInvalid, className, noOptionsText, + initialGroupStates, }: { modelConfigs: T[]; selectedModelConfig: T | undefined; @@ -137,6 +138,7 @@ export const ModelPicker = typedMemo( isInvalid?: boolean; className?: string; noOptionsText?: string; + initialGroupStates?: Record; }) => { const { t } = useTranslation(); const options = useMemo[]>(() => { @@ -244,6 +246,7 @@ export const ModelPicker = typedMemo( NextToSearchBar={} getIsOptionDisabled={getIsOptionDisabled} searchable + initialGroupStates={initialGroupStates} />