mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 22:14:59 -05:00
Enhance model picker with related models and improved filtering
Co-authored-by: kent <kent@invoke.ai>
This commit is contained in:
committed by
psychedelicious
parent
c37c8c50cd
commit
5fa6c0b413
@@ -1,26 +1,65 @@
|
||||
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import type { GroupStatusMap } from 'common/components/Picker/Picker';
|
||||
import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedModelCombobox';
|
||||
import { uniq } from 'es-toolkit/compat';
|
||||
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { ModelPicker } from 'features/parameters/components/ModelPicker';
|
||||
import { API_BASE_MODELS } from 'features/parameters/types/constants';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
const selectLoRAs = createSelector(selectLoRAsSlice, (loras) => loras.loras);
|
||||
|
||||
const selectSelectedModelKeys = createSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
|
||||
const keys: string[] = [];
|
||||
const main = params.model;
|
||||
const vae = params.vae;
|
||||
const refiner = params.refinerModel;
|
||||
const controlnet = params.controlLora;
|
||||
|
||||
if (main) {
|
||||
keys.push(main.key);
|
||||
}
|
||||
if (vae) {
|
||||
keys.push(vae.key);
|
||||
}
|
||||
if (refiner) {
|
||||
keys.push(refiner.key);
|
||||
}
|
||||
if (controlnet) {
|
||||
keys.push(controlnet.key);
|
||||
}
|
||||
for (const { model } of loras.loras) {
|
||||
keys.push(model.key);
|
||||
}
|
||||
|
||||
return uniq(keys);
|
||||
});
|
||||
|
||||
const LoRASelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||
const { t } = useTranslation();
|
||||
const addedLoRAs = useAppSelector(selectLoRAs);
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
const selectedKeys = useAppSelector(selectSelectedModelKeys);
|
||||
|
||||
const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, {
|
||||
selectFromResult: ({ data }) => {
|
||||
if (!data) {
|
||||
return { relatedKeys: EMPTY_ARRAY };
|
||||
}
|
||||
return { relatedKeys: data };
|
||||
},
|
||||
});
|
||||
|
||||
const currentBaseModel = useAppSelector((state) => state.controlLayers.present.params.model?.base);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(model: LoRAModelConfig): boolean => {
|
||||
@@ -42,23 +81,17 @@ const LoRASelect = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const { options } = useRelatedGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
getIsDisabled,
|
||||
onChange,
|
||||
});
|
||||
|
||||
const placeholder = useMemo(() => {
|
||||
if (isLoading) {
|
||||
return t('common.loading');
|
||||
}
|
||||
|
||||
if (options.length === 0) {
|
||||
if (modelConfigs.length === 0) {
|
||||
return t('models.noLoRAsInstalled');
|
||||
}
|
||||
|
||||
return t('models.addLora');
|
||||
}, [isLoading, options.length, t]);
|
||||
}, [isLoading, modelConfigs.length, t]);
|
||||
|
||||
// Calculate initial group states to default to the current base model architecture
|
||||
const initialGroupStates = useMemo(() => {
|
||||
@@ -82,6 +115,7 @@ const LoRASelect = () => {
|
||||
modelConfigs={modelConfigs}
|
||||
onChange={onChange}
|
||||
grouped
|
||||
relatedModelKeys={relatedKeys}
|
||||
selectedModelConfig={undefined}
|
||||
allowEmpty
|
||||
placeholder={placeholder}
|
||||
|
||||
@@ -118,6 +118,7 @@ export const ModelPicker = typedMemo(
|
||||
selectedModelConfig,
|
||||
onChange,
|
||||
grouped,
|
||||
relatedModelKeys = [],
|
||||
getIsOptionDisabled,
|
||||
placeholder,
|
||||
allowEmpty,
|
||||
@@ -131,6 +132,7 @@ export const ModelPicker = typedMemo(
|
||||
selectedModelConfig: T | undefined;
|
||||
onChange: (modelConfig: T) => void;
|
||||
grouped?: boolean;
|
||||
relatedModelKeys?: string[];
|
||||
getIsOptionDisabled?: (model: T) => boolean;
|
||||
placeholder?: string;
|
||||
allowEmpty?: boolean;
|
||||
@@ -143,13 +145,35 @@ export const ModelPicker = typedMemo(
|
||||
const { t } = useTranslation();
|
||||
const options = useMemo<T[] | Group<T>[]>(() => {
|
||||
if (!grouped) {
|
||||
// Handle related models for non-grouped view
|
||||
if (relatedModelKeys.length > 0) {
|
||||
const relatedModels: T[] = [];
|
||||
const otherModels: T[] = [];
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
if (relatedModelKeys.includes(modelConfig.key)) {
|
||||
relatedModels.push(modelConfig);
|
||||
} else {
|
||||
otherModels.push(modelConfig);
|
||||
}
|
||||
}
|
||||
|
||||
return [...relatedModels, ...otherModels];
|
||||
}
|
||||
return modelConfigs;
|
||||
}
|
||||
|
||||
// When all groups are disabled, we show all models
|
||||
const groups: Record<string, Group<T>> = {};
|
||||
const relatedModels: T[] = [];
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
// Check if this model is related and separate it
|
||||
if (relatedModelKeys.length > 0 && relatedModelKeys.includes(modelConfig.key)) {
|
||||
relatedModels.push(modelConfig);
|
||||
continue;
|
||||
}
|
||||
|
||||
const groupId = getGroupIDFromModelConfig(modelConfig);
|
||||
let group = groups[groupId];
|
||||
if (!group) {
|
||||
@@ -170,6 +194,20 @@ export const ModelPicker = typedMemo(
|
||||
|
||||
const _options: Group<T>[] = [];
|
||||
|
||||
// Add related models group first if there are any
|
||||
if (relatedModels.length > 0) {
|
||||
const relatedGroup = buildGroup<T>({
|
||||
id: 'related',
|
||||
color: 'accent.300',
|
||||
shortName: t('modelManager.showOnlyRelatedModels'),
|
||||
name: t('modelManager.relatedModels'),
|
||||
getOptionCountString: (count) => t('common.model_withCount', { count }),
|
||||
options: relatedModels,
|
||||
});
|
||||
_options.push(relatedGroup);
|
||||
}
|
||||
|
||||
// Add other groups in the original order
|
||||
for (const groupId of ['api', 'flux', 'cogview4', 'sdxl', 'sd-3', 'sd-2', 'sd-1']) {
|
||||
const group = groups[groupId];
|
||||
if (group) {
|
||||
@@ -180,7 +218,7 @@ export const ModelPicker = typedMemo(
|
||||
_options.push(...Object.values(groups));
|
||||
|
||||
return _options;
|
||||
}, [grouped, modelConfigs, t]);
|
||||
}, [grouped, modelConfigs, relatedModelKeys, t]);
|
||||
const popover = useDisclosure(false);
|
||||
const pickerRef = useRef<PickerContextState<T>>(null);
|
||||
|
||||
@@ -207,6 +245,18 @@ export const ModelPicker = typedMemo(
|
||||
return undefined;
|
||||
}, [allowEmpty, isInvalid, selectedModelConfig]);
|
||||
|
||||
// Create a component wrapper that includes related model styling
|
||||
const RelatedModelPickerOptionComponent = useCallback(
|
||||
({ option, ...rest }: { option: T } & BoxProps) => (
|
||||
<PickerOptionComponent
|
||||
option={option}
|
||||
isRelated={relatedModelKeys.includes(option.key)}
|
||||
{...rest}
|
||||
/>
|
||||
),
|
||||
[relatedModelKeys]
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
isOpen={popover.isOpen}
|
||||
@@ -240,7 +290,7 @@ export const ModelPicker = typedMemo(
|
||||
onSelect={onSelect}
|
||||
selectedOption={selectedModelConfig}
|
||||
isMatch={isMatch}
|
||||
OptionComponent={PickerOptionComponent}
|
||||
OptionComponent={RelatedModelPickerOptionComponent}
|
||||
noOptionsFallback={<NoOptionsFallback noOptionsText={noOptionsText} />}
|
||||
noMatchesFallback={t('modelManager.noMatchingModels')}
|
||||
NextToSearchBar={<NavigateToModelManagerButton />}
|
||||
@@ -291,17 +341,19 @@ const optionNameSx: SystemStyleObject = {
|
||||
},
|
||||
};
|
||||
|
||||
const PickerOptionComponent = typedMemo(({ option, ...rest }: { option: AnyModelConfig } & BoxProps) => {
|
||||
const PickerOptionComponent = typedMemo(({ option, isRelated = false, ...rest }: { option: AnyModelConfig; isRelated?: boolean } & BoxProps) => {
|
||||
const { $compactView } = usePickerContext<AnyModelConfig>();
|
||||
const compactView = useStore($compactView);
|
||||
|
||||
const displayName = isRelated ? `* ${option.name}` : option.name;
|
||||
|
||||
return (
|
||||
<Flex {...rest} sx={optionSx} data-is-compact={compactView}>
|
||||
{!compactView && option.cover_image && <ModelImage image_url={option.cover_image} />}
|
||||
<Flex flexDir="column" gap={1} flex={1}>
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Text sx={optionNameSx} data-is-compact={compactView}>
|
||||
{option.name}
|
||||
{displayName}
|
||||
</Text>
|
||||
<Spacer />
|
||||
{option.file_size > 0 && (
|
||||
|
||||
Reference in New Issue
Block a user