diff --git a/invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts index af14f5460e..ba06451256 100644 --- a/invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts @@ -1,12 +1,18 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { useAppSelector } from 'app/store/storeHooks'; import type { GroupBase } from 'chakra-react-select'; +import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; +import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import type { ModelIdentifierField } from 'features/nodes/types/common'; +import { uniq } from 'lodash-es'; +import { useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships'; import type { AnyModelConfig } from 'services/api/types'; import { useGroupedModelCombobox } from './useGroupedModelCombobox'; -import { useRelatedModelKeys } from './useRelatedModelKeys'; -import { useSelectedModelKeys } from './useSelectedModelKeys'; type UseRelatedGroupedModelComboboxArg = { modelConfigs: T[]; @@ -29,6 +35,32 @@ type UseRelatedGroupedModelComboboxReturn = { noOptionsMessage: () => string; }; +const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => { + const keys: string[] = []; + const main = params.model; + const vae = params.vae; + const refiner = params.refinerModel; + const controlnet = params.controlLora; + + if (main) { + keys.push(main.key); + } + if (vae) { + keys.push(vae.key); + } + if (refiner) { + keys.push(refiner.key); + } + if (controlnet) { + keys.push(controlnet.key); + } + for (const { model } of loras.loras) { + keys.push(model.key); + } + + return uniq(keys); +}); + export function useRelatedGroupedModelCombobox({ modelConfigs, selectedModel, @@ -39,9 +71,15 @@ export function useRelatedGroupedModelCombobox({ }: UseRelatedGroupedModelComboboxArg): UseRelatedGroupedModelComboboxReturn { const { t } = useTranslation(); - const selectedKeys = useSelectedModelKeys(); - - const relatedKeys = useRelatedModelKeys(selectedKeys); + const selectedKeys = useAppSelector(selectSelectedModelKeys); + const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, { + selectFromResult: ({ data }) => { + if (!data) { + return { relatedKeys: EMPTY_ARRAY }; + } + return { relatedKeys: data }; + }, + }); // Base grouped options const base = useGroupedModelCombobox({ @@ -53,40 +91,42 @@ export function useRelatedGroupedModelCombobox({ groupByType, }); - // If no related models selected, just return base - if (relatedKeys.size === 0) { - return base; - } + const options = useMemo(() => { + if (relatedKeys.length === 0) { + return base.options; + } - const relatedOptions: ComboboxOption[] = []; - const updatedGroups: GroupBase[] = []; + const relatedOptions: ComboboxOption[] = []; + const updatedGroups: GroupBase[] = []; - for (const group of base.options) { - const remainingOptions: ComboboxOption[] = []; + for (const group of base.options) { + const remainingOptions: ComboboxOption[] = []; - for (const option of group.options) { - if (relatedKeys.has(option.value)) { - relatedOptions.push({ ...option, label: `* ${option.label}` }); - } else { - remainingOptions.push(option); + for (const option of group.options) { + if (relatedKeys.includes(option.value)) { + relatedOptions.push({ ...option, label: `* ${option.label}` }); + } else { + remainingOptions.push(option); + } + } + + if (remainingOptions.length > 0) { + updatedGroups.push({ + label: group.label, + options: remainingOptions, + }); } } - if (remainingOptions.length > 0) { - updatedGroups.push({ - label: group.label, - options: remainingOptions, - }); + if (relatedOptions.length > 0) { + return [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups]; + } else { + return updatedGroups; } - } - - const finalOptions: GroupBase[] = - relatedOptions.length > 0 - ? [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups] - : updatedGroups; + }, [base.options, relatedKeys, t]); return { ...base, - options: finalOptions, + options, }; } diff --git a/invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts b/invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts index fc0711b969..9349c9b63d 100644 --- a/invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts +++ b/invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts @@ -1,14 +1,22 @@ +import { EMPTY_ARRAY } from 'app/store/constants'; import { useMemo } from 'react'; import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships'; +const options: Parameters[1] = { + selectFromResult: ({ data }) => { + if (!data) { + return { related: EMPTY_ARRAY }; + } + return data; + }, +}; + /** * Fetches related model keys for a given set of selected model keys. * Returns a Set for fast lookup. */ -export const useRelatedModelKeys = (selectedKeys: Set) => { - const { data: related = [] } = useGetRelatedModelIdsBatchQuery([...selectedKeys], { - skip: selectedKeys.size === 0, - }); +export const useRelatedModelKeys = (selectedKeys: string[]) => { + const { related } = useGetRelatedModelIdsBatchQuery(selectedKeys, options); return useMemo(() => new Set(related), [related]); };