diff --git a/invokeai/frontend/web/src/common/components/Picker/Picker.tsx b/invokeai/frontend/web/src/common/components/Picker/Picker.tsx index d0e3271c31..e5ff35971e 100644 --- a/invokeai/frontend/web/src/common/components/Picker/Picker.tsx +++ b/invokeai/frontend/web/src/common/components/Picker/Picker.tsx @@ -91,6 +91,10 @@ const isGroup = (optionOrGroup: OptionOrGroup): optionOrGro return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true; }; +export const isOption = (optionOrGroup: OptionOrGroup): optionOrGroup is T => { + return !(uniqueGroupKey in optionOrGroup); +}; + const DefaultOptionComponent = typedMemo(({ option }: { option: T }) => { const { getOptionId } = usePickerContext(); return {getOptionId(option)}; diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index acec0f0061..6426388e21 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -4,67 +4,29 @@ import { EMPTY_ARRAY } from 'app/store/constants'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import type { GroupStatusMap } from 'common/components/Picker/Picker'; -import { uniq } from 'es-toolkit/compat'; import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; -import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; +import { selectBase } from 'features/controlLayers/store/paramsSlice'; import { ModelPicker } from 'features/parameters/components/ModelPicker'; import { API_BASE_MODELS } from 'features/parameters/types/constants'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships'; import { useLoRAModels } from 'services/api/hooks/modelsByType'; import type { LoRAModelConfig } from 'services/api/types'; const selectLoRAs = createSelector(selectLoRAsSlice, (loras) => loras.loras); -const selectSelectedModelKeys = createSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => { - const keys: string[] = []; - const main = params.model; - const vae = params.vae; - const refiner = params.refinerModel; - const controlnet = params.controlLora; - - if (main) { - keys.push(main.key); - } - if (vae) { - keys.push(vae.key); - } - if (refiner) { - keys.push(refiner.key); - } - if (controlnet) { - keys.push(controlnet.key); - } - for (const { model } of loras.loras) { - keys.push(model.key); - } - - return uniq(keys); -}); - const LoRASelect = () => { const dispatch = useAppDispatch(); const [modelConfigs, { isLoading }] = useLoRAModels(); const { t } = useTranslation(); const addedLoRAs = useAppSelector(selectLoRAs); - const selectedKeys = useAppSelector(selectSelectedModelKeys); - const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, { - selectFromResult: ({ data }) => { - if (!data) { - return { relatedKeys: EMPTY_ARRAY }; - } - return { relatedKeys: data }; - }, - }); - - const currentBaseModel = useAppSelector((state) => state.params.model?.base); + const currentBaseModel = useAppSelector(selectBase); // Filter to only show compatible LoRAs const compatibleLoRAs = useMemo(() => { if (!currentBaseModel) { - return []; + return EMPTY_ARRAY; } return modelConfigs.filter((model) => model.base === currentBaseModel); }, [modelConfigs, currentBaseModel]); @@ -121,7 +83,6 @@ const LoRASelect = () => { modelConfigs={compatibleLoRAs} onChange={onChange} grouped={false} - relatedModelKeys={relatedKeys} selectedModelConfig={undefined} allowEmpty placeholder={placeholder} diff --git a/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx b/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx index 50e4837118..5a4adbde86 100644 --- a/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx @@ -2,6 +2,7 @@ import type { BoxProps, ButtonProps, SystemStyleObject } from '@invoke-ai/ui-lib import { Button, Flex, + Icon, Popover, PopoverArrow, PopoverBody, @@ -12,12 +13,17 @@ import { Text, } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { $onClickGoToModelManager } from 'app/store/nanostores/onClickGoToModelManager'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import type { Group, PickerContextState } from 'common/components/Picker/Picker'; -import { buildGroup, getRegex, Picker, usePickerContext } from 'common/components/Picker/Picker'; +import { buildGroup, getRegex, isOption, Picker, usePickerContext } from 'common/components/Picker/Picker'; import { useDisclosure } from 'common/hooks/useBoolean'; import { typedMemo } from 'common/util/typedMemo'; +import { uniq } from 'es-toolkit/compat'; +import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; +import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore'; import { BASE_COLOR_MAP } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge'; import ModelImage from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelImage'; @@ -29,10 +35,39 @@ import { filesize } from 'filesize'; import { memo, useCallback, useMemo, useRef } from 'react'; import { Trans, useTranslation } from 'react-i18next'; import { PiCaretDownBold, PiLinkSimple } from 'react-icons/pi'; +import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships'; import type { AnyModelConfig, BaseModelType } from 'services/api/types'; +const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => { + const keys: string[] = []; + const main = params.model; + const vae = params.vae; + const refiner = params.refinerModel; + const controlnet = params.controlLora; + + if (main) { + keys.push(main.key); + } + if (vae) { + keys.push(vae.key); + } + if (refiner) { + keys.push(refiner.key); + } + if (controlnet) { + keys.push(controlnet.key); + } + for (const { model } of loras.loras) { + keys.push(model.key); + } + + return uniq(keys); +}); + +type WithStarred = T & { starred?: boolean }; + // Type for models with starred field -const getOptionId = (modelConfig: T & { starred?: boolean }) => modelConfig.key; +const getOptionId = (modelConfig: WithStarred) => modelConfig.key; const ModelManagerLink = memo((props: ButtonProps) => { const onClickGoToModelManager = useStore($onClickGoToModelManager); @@ -105,6 +140,15 @@ const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfig): string return BASE_COLOR_MAP[modelConfig.base]; }; +const relatedModelKeysQueryOptions = { + selectFromResult: ({ data }) => { + if (!data) { + return { relatedModelKeys: EMPTY_ARRAY }; + } + return { relatedModelKeys: data }; + }, +} satisfies Parameters[1]; + const popperModifiers = [ { // Prevents the popover from "touching" the edges of the screen @@ -113,13 +157,17 @@ const popperModifiers = [ }, ]; +const removeStarred = (obj: WithStarred): T => { + const { starred: _, ...rest } = obj; + return rest as T; +}; + export const ModelPicker = typedMemo( ({ modelConfigs, selectedModelConfig, onChange, grouped, - relatedModelKeys = [], getIsOptionDisabled, placeholder, allowEmpty, @@ -133,7 +181,6 @@ export const ModelPicker = typedMemo( selectedModelConfig: T | undefined; onChange: (modelConfig: T) => void; grouped?: boolean; - relatedModelKeys?: string[]; getIsOptionDisabled?: (model: T) => boolean; placeholder?: string; allowEmpty?: boolean; @@ -144,7 +191,11 @@ export const ModelPicker = typedMemo( initialGroupStates?: Record; }) => { const { t } = useTranslation(); - const options = useMemo<(T & { starred?: boolean })[] | Group[]>(() => { + const selectedKeys = useAppSelector(selectSelectedModelKeys); + + const { relatedModelKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, relatedModelKeysQueryOptions); + + const options = useMemo[] | Group>[]>(() => { if (!grouped) { // Add starred field to model options and sort them const modelsWithStarred = modelConfigs.map((model) => ({ @@ -165,13 +216,13 @@ export const ModelPicker = typedMemo( } // When all groups are disabled, we show all models - const groups: Record> = {}; + const groups: Record>> = {}; for (const modelConfig of modelConfigs) { const groupId = getGroupIDFromModelConfig(modelConfig); let group = groups[groupId]; if (!group) { - group = buildGroup({ + group = buildGroup>({ id: modelConfig.base, color: `${getGroupColorSchemeFromModelConfig(modelConfig)}.300`, shortName: getGroupShortNameFromModelConfig(modelConfig), @@ -191,7 +242,7 @@ export const ModelPicker = typedMemo( } } - const _options: Group[] = []; + const _options: Group>[] = []; // Add groups in the original order for (const groupId of ['api', 'flux', 'cogview4', 'sdxl', 'sd-3', 'sd-2', 'sd-1']) { @@ -216,7 +267,15 @@ export const ModelPicker = typedMemo( return _options; }, [grouped, modelConfigs, relatedModelKeys, t]); const popover = useDisclosure(false); - const pickerRef = useRef>(null); + const pickerRef = useRef>>(null); + + const selectedOption = useMemo | undefined>(() => { + if (!selectedModelConfig) { + return undefined; + } + + return options.filter(isOption).find((o) => o.key === selectedModelConfig.key); + }, [options, selectedModelConfig]); const onClose = useCallback(() => { popover.close(); @@ -224,11 +283,10 @@ export const ModelPicker = typedMemo( }, [popover]); const onSelect = useCallback( - (model: T & { starred?: boolean }) => { + (model: WithStarred) => { onClose(); // Remove the starred field before passing to onChange - const { starred: _, ...modelWithoutStarred } = model; - onChange(modelWithoutStarred as T); + onChange(removeStarred(model)); }, [onChange, onClose] ); @@ -268,17 +326,13 @@ export const ModelPicker = typedMemo( - - + + > handleRef={pickerRef} optionsOrGroups={options} getOptionId={getOptionId} onSelect={onSelect} - selectedOption={ - selectedModelConfig - ? { ...selectedModelConfig, starred: relatedModelKeys.includes(selectedModelConfig.key) } - : undefined - } + selectedOption={selectedOption} isMatch={isMatch} OptionComponent={PickerOptionComponent} noOptionsFallback={} @@ -332,8 +386,8 @@ const optionNameSx: SystemStyleObject = { }; const PickerOptionComponent = typedMemo( - ({ option, ...rest }: { option: T & { starred?: boolean } } & BoxProps) => { - const { $compactView } = usePickerContext(); + ({ option, ...rest }: { option: WithStarred } & BoxProps) => { + const { $compactView } = usePickerContext>(); const compactView = useStore($compactView); return ( @@ -341,7 +395,7 @@ const PickerOptionComponent = typedMemo( {!compactView && option.cover_image && } - {option.starred && } + {option.starred && } {option.name} @@ -371,7 +425,7 @@ const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = { 'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'], }; -const isMatch = (model: T & { starred?: boolean }, searchTerm: string) => { +const isMatch = (model: WithStarred, searchTerm: string) => { const regex = getRegex(searchTerm); const bases = BASE_KEYWORDS[model.base] ?? [model.base]; const testString =