Enhance model picker with related models and improved filtering

Co-authored-by: kent <kent@invoke.ai>
This commit is contained in:
Cursor Agent
2025-06-27 01:16:52 +00:00
committed by psychedelicious
parent c37c8c50cd
commit 5fa6c0b413
2 changed files with 101 additions and 15 deletions

View File

@@ -1,26 +1,65 @@
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type { GroupStatusMap } from 'common/components/Picker/Picker';
import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedModelCombobox';
import { uniq } from 'es-toolkit/compat';
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { ModelPicker } from 'features/parameters/components/ModelPicker';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';
const selectLoRAs = createSelector(selectLoRAsSlice, (loras) => loras.loras);
const selectSelectedModelKeys = createSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
const keys: string[] = [];
const main = params.model;
const vae = params.vae;
const refiner = params.refinerModel;
const controlnet = params.controlLora;
if (main) {
keys.push(main.key);
}
if (vae) {
keys.push(vae.key);
}
if (refiner) {
keys.push(refiner.key);
}
if (controlnet) {
keys.push(controlnet.key);
}
for (const { model } of loras.loras) {
keys.push(model.key);
}
return uniq(keys);
});
const LoRASelect = () => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLoRAModels();
const { t } = useTranslation();
const addedLoRAs = useAppSelector(selectLoRAs);
const currentBaseModel = useAppSelector(selectBase);
const selectedKeys = useAppSelector(selectSelectedModelKeys);
const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, {
selectFromResult: ({ data }) => {
if (!data) {
return { relatedKeys: EMPTY_ARRAY };
}
return { relatedKeys: data };
},
});
const currentBaseModel = useAppSelector((state) => state.controlLayers.present.params.model?.base);
const getIsDisabled = useCallback(
(model: LoRAModelConfig): boolean => {
@@ -42,23 +81,17 @@ const LoRASelect = () => {
[dispatch]
);
const { options } = useRelatedGroupedModelCombobox({
modelConfigs,
getIsDisabled,
onChange,
});
const placeholder = useMemo(() => {
if (isLoading) {
return t('common.loading');
}
if (options.length === 0) {
if (modelConfigs.length === 0) {
return t('models.noLoRAsInstalled');
}
return t('models.addLora');
}, [isLoading, options.length, t]);
}, [isLoading, modelConfigs.length, t]);
// Calculate initial group states to default to the current base model architecture
const initialGroupStates = useMemo(() => {
@@ -82,6 +115,7 @@ const LoRASelect = () => {
modelConfigs={modelConfigs}
onChange={onChange}
grouped
relatedModelKeys={relatedKeys}
selectedModelConfig={undefined}
allowEmpty
placeholder={placeholder}

View File

@@ -118,6 +118,7 @@ export const ModelPicker = typedMemo(
selectedModelConfig,
onChange,
grouped,
relatedModelKeys = [],
getIsOptionDisabled,
placeholder,
allowEmpty,
@@ -131,6 +132,7 @@ export const ModelPicker = typedMemo(
selectedModelConfig: T | undefined;
onChange: (modelConfig: T) => void;
grouped?: boolean;
relatedModelKeys?: string[];
getIsOptionDisabled?: (model: T) => boolean;
placeholder?: string;
allowEmpty?: boolean;
@@ -143,13 +145,35 @@ export const ModelPicker = typedMemo(
const { t } = useTranslation();
const options = useMemo<T[] | Group<T>[]>(() => {
if (!grouped) {
// Handle related models for non-grouped view
if (relatedModelKeys.length > 0) {
const relatedModels: T[] = [];
const otherModels: T[] = [];
for (const modelConfig of modelConfigs) {
if (relatedModelKeys.includes(modelConfig.key)) {
relatedModels.push(modelConfig);
} else {
otherModels.push(modelConfig);
}
}
return [...relatedModels, ...otherModels];
}
return modelConfigs;
}
// When all groups are disabled, we show all models
const groups: Record<string, Group<T>> = {};
const relatedModels: T[] = [];
for (const modelConfig of modelConfigs) {
// Check if this model is related and separate it
if (relatedModelKeys.length > 0 && relatedModelKeys.includes(modelConfig.key)) {
relatedModels.push(modelConfig);
continue;
}
const groupId = getGroupIDFromModelConfig(modelConfig);
let group = groups[groupId];
if (!group) {
@@ -170,6 +194,20 @@ export const ModelPicker = typedMemo(
const _options: Group<T>[] = [];
// Add related models group first if there are any
if (relatedModels.length > 0) {
const relatedGroup = buildGroup<T>({
id: 'related',
color: 'accent.300',
shortName: t('modelManager.showOnlyRelatedModels'),
name: t('modelManager.relatedModels'),
getOptionCountString: (count) => t('common.model_withCount', { count }),
options: relatedModels,
});
_options.push(relatedGroup);
}
// Add other groups in the original order
for (const groupId of ['api', 'flux', 'cogview4', 'sdxl', 'sd-3', 'sd-2', 'sd-1']) {
const group = groups[groupId];
if (group) {
@@ -180,7 +218,7 @@ export const ModelPicker = typedMemo(
_options.push(...Object.values(groups));
return _options;
}, [grouped, modelConfigs, t]);
}, [grouped, modelConfigs, relatedModelKeys, t]);
const popover = useDisclosure(false);
const pickerRef = useRef<PickerContextState<T>>(null);
@@ -207,6 +245,18 @@ export const ModelPicker = typedMemo(
return undefined;
}, [allowEmpty, isInvalid, selectedModelConfig]);
// Create a component wrapper that includes related model styling
const RelatedModelPickerOptionComponent = useCallback(
({ option, ...rest }: { option: T } & BoxProps) => (
<PickerOptionComponent
option={option}
isRelated={relatedModelKeys.includes(option.key)}
{...rest}
/>
),
[relatedModelKeys]
);
return (
<Popover
isOpen={popover.isOpen}
@@ -240,7 +290,7 @@ export const ModelPicker = typedMemo(
onSelect={onSelect}
selectedOption={selectedModelConfig}
isMatch={isMatch}
OptionComponent={PickerOptionComponent}
OptionComponent={RelatedModelPickerOptionComponent}
noOptionsFallback={<NoOptionsFallback noOptionsText={noOptionsText} />}
noMatchesFallback={t('modelManager.noMatchingModels')}
NextToSearchBar={<NavigateToModelManagerButton />}
@@ -291,17 +341,19 @@ const optionNameSx: SystemStyleObject = {
},
};
const PickerOptionComponent = typedMemo(({ option, ...rest }: { option: AnyModelConfig } & BoxProps) => {
const PickerOptionComponent = typedMemo(({ option, isRelated = false, ...rest }: { option: AnyModelConfig; isRelated?: boolean } & BoxProps) => {
const { $compactView } = usePickerContext<AnyModelConfig>();
const compactView = useStore($compactView);
const displayName = isRelated ? `* ${option.name}` : option.name;
return (
<Flex {...rest} sx={optionSx} data-is-compact={compactView}>
{!compactView && option.cover_image && <ModelImage image_url={option.cover_image} />}
<Flex flexDir="column" gap={1} flex={1}>
<Flex gap={2} alignItems="center">
<Text sx={optionNameSx} data-is-compact={compactView}>
{option.name}
{displayName}
</Text>
<Spacer />
{option.file_size > 0 && (