mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 18:45:07 -05:00
refactor(ui): clean up related models impl for picker
This commit is contained in:
@@ -91,6 +91,10 @@ const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGro
|
||||
return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true;
|
||||
};
|
||||
|
||||
export const isOption = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is T => {
|
||||
return !(uniqueGroupKey in optionOrGroup);
|
||||
};
|
||||
|
||||
const DefaultOptionComponent = typedMemo(<T extends object>({ option }: { option: T }) => {
|
||||
const { getOptionId } = usePickerContext();
|
||||
return <Text fontWeight="bold">{getOptionId(option)}</Text>;
|
||||
|
||||
@@ -4,67 +4,29 @@ import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import type { GroupStatusMap } from 'common/components/Picker/Picker';
|
||||
import { uniq } from 'es-toolkit/compat';
|
||||
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { ModelPicker } from 'features/parameters/components/ModelPicker';
|
||||
import { API_BASE_MODELS } from 'features/parameters/types/constants';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
const selectLoRAs = createSelector(selectLoRAsSlice, (loras) => loras.loras);
|
||||
|
||||
const selectSelectedModelKeys = createSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
|
||||
const keys: string[] = [];
|
||||
const main = params.model;
|
||||
const vae = params.vae;
|
||||
const refiner = params.refinerModel;
|
||||
const controlnet = params.controlLora;
|
||||
|
||||
if (main) {
|
||||
keys.push(main.key);
|
||||
}
|
||||
if (vae) {
|
||||
keys.push(vae.key);
|
||||
}
|
||||
if (refiner) {
|
||||
keys.push(refiner.key);
|
||||
}
|
||||
if (controlnet) {
|
||||
keys.push(controlnet.key);
|
||||
}
|
||||
for (const { model } of loras.loras) {
|
||||
keys.push(model.key);
|
||||
}
|
||||
|
||||
return uniq(keys);
|
||||
});
|
||||
|
||||
const LoRASelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||
const { t } = useTranslation();
|
||||
const addedLoRAs = useAppSelector(selectLoRAs);
|
||||
const selectedKeys = useAppSelector(selectSelectedModelKeys);
|
||||
|
||||
const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, {
|
||||
selectFromResult: ({ data }) => {
|
||||
if (!data) {
|
||||
return { relatedKeys: EMPTY_ARRAY };
|
||||
}
|
||||
return { relatedKeys: data };
|
||||
},
|
||||
});
|
||||
|
||||
const currentBaseModel = useAppSelector((state) => state.params.model?.base);
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
|
||||
// Filter to only show compatible LoRAs
|
||||
const compatibleLoRAs = useMemo(() => {
|
||||
if (!currentBaseModel) {
|
||||
return [];
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
return modelConfigs.filter((model) => model.base === currentBaseModel);
|
||||
}, [modelConfigs, currentBaseModel]);
|
||||
@@ -121,7 +83,6 @@ const LoRASelect = () => {
|
||||
modelConfigs={compatibleLoRAs}
|
||||
onChange={onChange}
|
||||
grouped={false}
|
||||
relatedModelKeys={relatedKeys}
|
||||
selectedModelConfig={undefined}
|
||||
allowEmpty
|
||||
placeholder={placeholder}
|
||||
|
||||
@@ -2,6 +2,7 @@ import type { BoxProps, ButtonProps, SystemStyleObject } from '@invoke-ai/ui-lib
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
Icon,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
@@ -12,12 +13,17 @@ import {
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { $onClickGoToModelManager } from 'app/store/nanostores/onClickGoToModelManager';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { Group, PickerContextState } from 'common/components/Picker/Picker';
|
||||
import { buildGroup, getRegex, Picker, usePickerContext } from 'common/components/Picker/Picker';
|
||||
import { buildGroup, getRegex, isOption, Picker, usePickerContext } from 'common/components/Picker/Picker';
|
||||
import { useDisclosure } from 'common/hooks/useBoolean';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { uniq } from 'es-toolkit/compat';
|
||||
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore';
|
||||
import { BASE_COLOR_MAP } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import ModelImage from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelImage';
|
||||
@@ -29,10 +35,39 @@ import { filesize } from 'filesize';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold, PiLinkSimple } from 'react-icons/pi';
|
||||
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
|
||||
import type { AnyModelConfig, BaseModelType } from 'services/api/types';
|
||||
|
||||
const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
|
||||
const keys: string[] = [];
|
||||
const main = params.model;
|
||||
const vae = params.vae;
|
||||
const refiner = params.refinerModel;
|
||||
const controlnet = params.controlLora;
|
||||
|
||||
if (main) {
|
||||
keys.push(main.key);
|
||||
}
|
||||
if (vae) {
|
||||
keys.push(vae.key);
|
||||
}
|
||||
if (refiner) {
|
||||
keys.push(refiner.key);
|
||||
}
|
||||
if (controlnet) {
|
||||
keys.push(controlnet.key);
|
||||
}
|
||||
for (const { model } of loras.loras) {
|
||||
keys.push(model.key);
|
||||
}
|
||||
|
||||
return uniq(keys);
|
||||
});
|
||||
|
||||
type WithStarred<T> = T & { starred?: boolean };
|
||||
|
||||
// Type for models with starred field
|
||||
const getOptionId = <T extends AnyModelConfig>(modelConfig: T & { starred?: boolean }) => modelConfig.key;
|
||||
const getOptionId = <T extends AnyModelConfig>(modelConfig: WithStarred<T>) => modelConfig.key;
|
||||
|
||||
const ModelManagerLink = memo((props: ButtonProps) => {
|
||||
const onClickGoToModelManager = useStore($onClickGoToModelManager);
|
||||
@@ -105,6 +140,15 @@ const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfig): string
|
||||
return BASE_COLOR_MAP[modelConfig.base];
|
||||
};
|
||||
|
||||
const relatedModelKeysQueryOptions = {
|
||||
selectFromResult: ({ data }) => {
|
||||
if (!data) {
|
||||
return { relatedModelKeys: EMPTY_ARRAY };
|
||||
}
|
||||
return { relatedModelKeys: data };
|
||||
},
|
||||
} satisfies Parameters<typeof useGetRelatedModelIdsBatchQuery>[1];
|
||||
|
||||
const popperModifiers = [
|
||||
{
|
||||
// Prevents the popover from "touching" the edges of the screen
|
||||
@@ -113,13 +157,17 @@ const popperModifiers = [
|
||||
},
|
||||
];
|
||||
|
||||
const removeStarred = <T,>(obj: WithStarred<T>): T => {
|
||||
const { starred: _, ...rest } = obj;
|
||||
return rest as T;
|
||||
};
|
||||
|
||||
export const ModelPicker = typedMemo(
|
||||
<T extends AnyModelConfig = AnyModelConfig>({
|
||||
modelConfigs,
|
||||
selectedModelConfig,
|
||||
onChange,
|
||||
grouped,
|
||||
relatedModelKeys = [],
|
||||
getIsOptionDisabled,
|
||||
placeholder,
|
||||
allowEmpty,
|
||||
@@ -133,7 +181,6 @@ export const ModelPicker = typedMemo(
|
||||
selectedModelConfig: T | undefined;
|
||||
onChange: (modelConfig: T) => void;
|
||||
grouped?: boolean;
|
||||
relatedModelKeys?: string[];
|
||||
getIsOptionDisabled?: (model: T) => boolean;
|
||||
placeholder?: string;
|
||||
allowEmpty?: boolean;
|
||||
@@ -144,7 +191,11 @@ export const ModelPicker = typedMemo(
|
||||
initialGroupStates?: Record<string, boolean>;
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const options = useMemo<(T & { starred?: boolean })[] | Group<T & { starred?: boolean }>[]>(() => {
|
||||
const selectedKeys = useAppSelector(selectSelectedModelKeys);
|
||||
|
||||
const { relatedModelKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, relatedModelKeysQueryOptions);
|
||||
|
||||
const options = useMemo<WithStarred<T>[] | Group<WithStarred<T>>[]>(() => {
|
||||
if (!grouped) {
|
||||
// Add starred field to model options and sort them
|
||||
const modelsWithStarred = modelConfigs.map((model) => ({
|
||||
@@ -165,13 +216,13 @@ export const ModelPicker = typedMemo(
|
||||
}
|
||||
|
||||
// When all groups are disabled, we show all models
|
||||
const groups: Record<string, Group<T & { starred?: boolean }>> = {};
|
||||
const groups: Record<string, Group<WithStarred<T>>> = {};
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
const groupId = getGroupIDFromModelConfig(modelConfig);
|
||||
let group = groups[groupId];
|
||||
if (!group) {
|
||||
group = buildGroup<T & { starred?: boolean }>({
|
||||
group = buildGroup<WithStarred<T>>({
|
||||
id: modelConfig.base,
|
||||
color: `${getGroupColorSchemeFromModelConfig(modelConfig)}.300`,
|
||||
shortName: getGroupShortNameFromModelConfig(modelConfig),
|
||||
@@ -191,7 +242,7 @@ export const ModelPicker = typedMemo(
|
||||
}
|
||||
}
|
||||
|
||||
const _options: Group<T & { starred?: boolean }>[] = [];
|
||||
const _options: Group<WithStarred<T>>[] = [];
|
||||
|
||||
// Add groups in the original order
|
||||
for (const groupId of ['api', 'flux', 'cogview4', 'sdxl', 'sd-3', 'sd-2', 'sd-1']) {
|
||||
@@ -216,7 +267,15 @@ export const ModelPicker = typedMemo(
|
||||
return _options;
|
||||
}, [grouped, modelConfigs, relatedModelKeys, t]);
|
||||
const popover = useDisclosure(false);
|
||||
const pickerRef = useRef<PickerContextState<T & { starred?: boolean }>>(null);
|
||||
const pickerRef = useRef<PickerContextState<WithStarred<T>>>(null);
|
||||
|
||||
const selectedOption = useMemo<WithStarred<T> | undefined>(() => {
|
||||
if (!selectedModelConfig) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return options.filter(isOption).find((o) => o.key === selectedModelConfig.key);
|
||||
}, [options, selectedModelConfig]);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
popover.close();
|
||||
@@ -224,11 +283,10 @@ export const ModelPicker = typedMemo(
|
||||
}, [popover]);
|
||||
|
||||
const onSelect = useCallback(
|
||||
(model: T & { starred?: boolean }) => {
|
||||
(model: WithStarred<T>) => {
|
||||
onClose();
|
||||
// Remove the starred field before passing to onChange
|
||||
const { starred: _, ...modelWithoutStarred } = model;
|
||||
onChange(modelWithoutStarred as T);
|
||||
onChange(removeStarred(model));
|
||||
},
|
||||
[onChange, onClose]
|
||||
);
|
||||
@@ -268,17 +326,13 @@ export const ModelPicker = typedMemo(
|
||||
<Portal appendToParentPortal={false}>
|
||||
<PopoverContent p={0} w={400} h={400}>
|
||||
<PopoverArrow />
|
||||
<PopoverBody p={0} w="full" h="full">
|
||||
<Picker<T & { starred?: boolean }>
|
||||
<PopoverBody p={0} w="full" h="full" borderWidth={1} borderColor="base.700" borderRadius="base">
|
||||
<Picker<WithStarred<T>>
|
||||
handleRef={pickerRef}
|
||||
optionsOrGroups={options}
|
||||
getOptionId={getOptionId<T>}
|
||||
onSelect={onSelect}
|
||||
selectedOption={
|
||||
selectedModelConfig
|
||||
? { ...selectedModelConfig, starred: relatedModelKeys.includes(selectedModelConfig.key) }
|
||||
: undefined
|
||||
}
|
||||
selectedOption={selectedOption}
|
||||
isMatch={isMatch<T>}
|
||||
OptionComponent={PickerOptionComponent<T>}
|
||||
noOptionsFallback={<NoOptionsFallback noOptionsText={noOptionsText} />}
|
||||
@@ -332,8 +386,8 @@ const optionNameSx: SystemStyleObject = {
|
||||
};
|
||||
|
||||
const PickerOptionComponent = typedMemo(
|
||||
<T extends AnyModelConfig>({ option, ...rest }: { option: T & { starred?: boolean } } & BoxProps) => {
|
||||
const { $compactView } = usePickerContext<T & { starred?: boolean }>();
|
||||
<T extends AnyModelConfig>({ option, ...rest }: { option: WithStarred<T> } & BoxProps) => {
|
||||
const { $compactView } = usePickerContext<WithStarred<T>>();
|
||||
const compactView = useStore($compactView);
|
||||
|
||||
return (
|
||||
@@ -341,7 +395,7 @@ const PickerOptionComponent = typedMemo(
|
||||
{!compactView && option.cover_image && <ModelImage image_url={option.cover_image} />}
|
||||
<Flex flexDir="column" gap={1} flex={1}>
|
||||
<Flex gap={2} alignItems="center">
|
||||
{option.starred && <PiLinkSimple color="yellow" size={16} />}
|
||||
{option.starred && <Icon as={PiLinkSimple} color="invokeYellow.500" boxSize={4} />}
|
||||
<Text sx={optionNameSx} data-is-compact={compactView}>
|
||||
{option.name}
|
||||
</Text>
|
||||
@@ -371,7 +425,7 @@ const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
|
||||
'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'],
|
||||
};
|
||||
|
||||
const isMatch = <T extends AnyModelConfig>(model: T & { starred?: boolean }, searchTerm: string) => {
|
||||
const isMatch = <T extends AnyModelConfig>(model: WithStarred<T>, searchTerm: string) => {
|
||||
const regex = getRegex(searchTerm);
|
||||
const bases = BASE_KEYWORDS[model.base] ?? [model.base];
|
||||
const testString =
|
||||
|
||||
Reference in New Issue
Block a user