mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 21:04:57 -05:00
fix(ui): unstable selector results in lora drop down
This commit is contained in:
@@ -1,12 +1,18 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase } from 'chakra-react-select';
|
||||
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { uniq } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { useGroupedModelCombobox } from './useGroupedModelCombobox';
|
||||
import { useRelatedModelKeys } from './useRelatedModelKeys';
|
||||
import { useSelectedModelKeys } from './useSelectedModelKeys';
|
||||
|
||||
type UseRelatedGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelConfigs: T[];
|
||||
@@ -29,6 +35,32 @@ type UseRelatedGroupedModelComboboxReturn = {
|
||||
noOptionsMessage: () => string;
|
||||
};
|
||||
|
||||
const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
|
||||
const keys: string[] = [];
|
||||
const main = params.model;
|
||||
const vae = params.vae;
|
||||
const refiner = params.refinerModel;
|
||||
const controlnet = params.controlLora;
|
||||
|
||||
if (main) {
|
||||
keys.push(main.key);
|
||||
}
|
||||
if (vae) {
|
||||
keys.push(vae.key);
|
||||
}
|
||||
if (refiner) {
|
||||
keys.push(refiner.key);
|
||||
}
|
||||
if (controlnet) {
|
||||
keys.push(controlnet.key);
|
||||
}
|
||||
for (const { model } of loras.loras) {
|
||||
keys.push(model.key);
|
||||
}
|
||||
|
||||
return uniq(keys);
|
||||
});
|
||||
|
||||
export function useRelatedGroupedModelCombobox<T extends AnyModelConfig>({
|
||||
modelConfigs,
|
||||
selectedModel,
|
||||
@@ -39,9 +71,15 @@ export function useRelatedGroupedModelCombobox<T extends AnyModelConfig>({
|
||||
}: UseRelatedGroupedModelComboboxArg<T>): UseRelatedGroupedModelComboboxReturn {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const selectedKeys = useSelectedModelKeys();
|
||||
|
||||
const relatedKeys = useRelatedModelKeys(selectedKeys);
|
||||
const selectedKeys = useAppSelector(selectSelectedModelKeys);
|
||||
const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, {
|
||||
selectFromResult: ({ data }) => {
|
||||
if (!data) {
|
||||
return { relatedKeys: EMPTY_ARRAY };
|
||||
}
|
||||
return { relatedKeys: data };
|
||||
},
|
||||
});
|
||||
|
||||
// Base grouped options
|
||||
const base = useGroupedModelCombobox({
|
||||
@@ -53,40 +91,42 @@ export function useRelatedGroupedModelCombobox<T extends AnyModelConfig>({
|
||||
groupByType,
|
||||
});
|
||||
|
||||
// If no related models selected, just return base
|
||||
if (relatedKeys.size === 0) {
|
||||
return base;
|
||||
}
|
||||
const options = useMemo(() => {
|
||||
if (relatedKeys.length === 0) {
|
||||
return base.options;
|
||||
}
|
||||
|
||||
const relatedOptions: ComboboxOption[] = [];
|
||||
const updatedGroups: GroupBase<ComboboxOption>[] = [];
|
||||
const relatedOptions: ComboboxOption[] = [];
|
||||
const updatedGroups: GroupBase<ComboboxOption>[] = [];
|
||||
|
||||
for (const group of base.options) {
|
||||
const remainingOptions: ComboboxOption[] = [];
|
||||
for (const group of base.options) {
|
||||
const remainingOptions: ComboboxOption[] = [];
|
||||
|
||||
for (const option of group.options) {
|
||||
if (relatedKeys.has(option.value)) {
|
||||
relatedOptions.push({ ...option, label: `* ${option.label}` });
|
||||
} else {
|
||||
remainingOptions.push(option);
|
||||
for (const option of group.options) {
|
||||
if (relatedKeys.includes(option.value)) {
|
||||
relatedOptions.push({ ...option, label: `* ${option.label}` });
|
||||
} else {
|
||||
remainingOptions.push(option);
|
||||
}
|
||||
}
|
||||
|
||||
if (remainingOptions.length > 0) {
|
||||
updatedGroups.push({
|
||||
label: group.label,
|
||||
options: remainingOptions,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (remainingOptions.length > 0) {
|
||||
updatedGroups.push({
|
||||
label: group.label,
|
||||
options: remainingOptions,
|
||||
});
|
||||
if (relatedOptions.length > 0) {
|
||||
return [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups];
|
||||
} else {
|
||||
return updatedGroups;
|
||||
}
|
||||
}
|
||||
|
||||
const finalOptions: GroupBase<ComboboxOption>[] =
|
||||
relatedOptions.length > 0
|
||||
? [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups]
|
||||
: updatedGroups;
|
||||
}, [base.options, relatedKeys, t]);
|
||||
|
||||
return {
|
||||
...base,
|
||||
options: finalOptions,
|
||||
options,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
|
||||
|
||||
const options: Parameters<typeof useGetRelatedModelIdsBatchQuery>[1] = {
|
||||
selectFromResult: ({ data }) => {
|
||||
if (!data) {
|
||||
return { related: EMPTY_ARRAY };
|
||||
}
|
||||
return data;
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches related model keys for a given set of selected model keys.
|
||||
* Returns a Set<string> for fast lookup.
|
||||
*/
|
||||
export const useRelatedModelKeys = (selectedKeys: Set<string>) => {
|
||||
const { data: related = [] } = useGetRelatedModelIdsBatchQuery([...selectedKeys], {
|
||||
skip: selectedKeys.size === 0,
|
||||
});
|
||||
export const useRelatedModelKeys = (selectedKeys: string[]) => {
|
||||
const { related } = useGetRelatedModelIdsBatchQuery(selectedKeys, options);
|
||||
|
||||
return useMemo(() => new Set(related), [related]);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user