feat(ui): model picker filter buttons

This commit is contained in:
psychedelicious
2025-04-17 22:21:16 +10:00
parent aeb3841a6f
commit 97d45ceaf2
2 changed files with 114 additions and 22 deletions

View File

@@ -1,5 +1,6 @@
import type { BoxProps, FormLabelProps, InputProps, SystemStyleObject } from '@invoke-ai/ui-library';
import {
Badge,
Box,
Button,
Expander,
@@ -23,7 +24,13 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type { Group, ImperativeModelPickerHandle } from 'common/components/Picker/Picker';
import { DefaultNoMatchesFallback, DefaultNoOptionsFallback, getRegex, Picker } from 'common/components/Picker/Picker';
import {
DefaultNoMatchesFallback,
DefaultNoOptionsFallback,
getRegex,
Picker,
usePickerContext,
} from 'common/components/Picker/Picker';
import { useDisclosure } from 'common/hooks/useBoolean';
import { fixedForwardRef } from 'common/util/fixedForwardRef';
import { typedMemo } from 'common/util/typedMemo';
@@ -48,8 +55,9 @@ import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/
import { selectActiveTab, selectCompactModelPicker } from 'features/ui/store/uiSelectors';
import { compactModelPickerToggled } from 'features/ui/store/uiSlice';
import { filesize } from 'filesize';
import { isEqual } from 'lodash-es';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsInLineVerticalBold, PiArrowsOutLineVerticalBold, PiCaretDownBold } from 'react-icons/pi';
import { useMainModels } from 'services/api/hooks/modelsByType';
@@ -131,9 +139,57 @@ type GroupData = {
description: string;
};
type BaseModelTypeFilters = { [key in BaseModelType]?: boolean };
type PickerExtraContext = {
toggleBaseModelTypeFilter: (baseModelType: BaseModelType) => void;
basesWithModels: BaseModelType[];
baseModelTypeFilters: BaseModelTypeFilters;
};
const MainModelPicker = memo(() => {
const { t } = useTranslation();
const [modelConfigs] = useMainModels();
const basesWithModels = useMemo(() => {
const bases: BaseModelType[] = [];
for (const modelConfig of modelConfigs) {
if (!bases.includes(modelConfig.base)) {
bases.push(modelConfig.base);
}
}
return bases;
}, [modelConfigs]);
const [baseModelTypeFilters, setBaseModelTypeFilters] = useState<BaseModelTypeFilters>({});
useEffect(() => {
const newFilters: BaseModelTypeFilters = {};
if (isEqual(Object.keys(baseModelTypeFilters), basesWithModels)) {
return;
}
for (const base of basesWithModels) {
if (newFilters[base] === undefined) {
newFilters[base] = true;
} else {
newFilters[base] = baseModelTypeFilters[base];
}
}
setBaseModelTypeFilters(newFilters);
}, [baseModelTypeFilters, basesWithModels]);
const toggleBaseModelTypeFilter = useCallback(
(baseModelType: BaseModelType) => {
setBaseModelTypeFilters((prev) => {
const newFilters: BaseModelTypeFilters = {};
for (const base of basesWithModels) {
newFilters[base] = baseModelType === base ? !prev[base] : prev[base];
}
return newFilters;
});
},
[basesWithModels]
);
const ctx = useMemo(
() => ({ toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters }),
[toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters]
);
const grouped = useMemo<Group<AnyModelConfig, GroupData>[]>(() => {
const groups: {
[base in BaseModelType]?: Group<AnyModelConfig, GroupData>;
@@ -141,7 +197,7 @@ const MainModelPicker = memo(() => {
for (const modelConfig of modelConfigs) {
let group = groups[modelConfig.base];
if (!group) {
if (!group && baseModelTypeFilters[modelConfig.base]) {
group = {
id: modelConfig.base,
data: { base: modelConfig.base, description: `A brief description of ${modelConfig.base} models.` },
@@ -149,8 +205,9 @@ const MainModelPicker = memo(() => {
};
groups[modelConfig.base] = group;
}
group.options.push(modelConfig);
if (group) {
group.options.push(modelConfig);
}
}
const sortedGroups: Group<AnyModelConfig, GroupData>[] = [];
@@ -178,7 +235,7 @@ const MainModelPicker = memo(() => {
sortedGroups.push(...Object.values(groups));
return sortedGroups;
}, [modelConfigs]);
}, [baseModelTypeFilters, modelConfigs]);
const modelConfig = useSelectedModelConfig();
const popover = useDisclosure(false);
const pickerRef = useRef<ImperativeModelPickerHandle>(null);
@@ -222,7 +279,7 @@ const MainModelPicker = memo(() => {
<PopoverContent p={0} w={448} h={512}>
<PopoverArrow />
<PopoverBody p={0} w="full" h="full">
<Picker<AnyModelConfig, GroupData>
<Picker<AnyModelConfig, GroupData, PickerExtraContext>
handleRef={pickerRef}
options={grouped}
getOptionId={getOptionId}
@@ -234,6 +291,7 @@ const MainModelPicker = memo(() => {
SearchBarComponent={SearchBarComponent}
noOptionsFallback={<DefaultNoOptionsFallback label={t('modelManager.noModelsInstalled')} />}
noMatchesFallback={<DefaultNoMatchesFallback label={t('modelManager.noMatchingModels')} />}
ctx={ctx}
/>
</PopoverBody>
</PopoverContent>
@@ -248,12 +306,12 @@ const SearchBarComponent = typedMemo(
const { t } = useTranslation();
const dispatch = useAppDispatch();
const compactModelPicker = useAppSelector(selectCompactModelPicker);
const { ctx } = usePickerContext<AnyModelConfig, GroupData, PickerExtraContext>();
const onToggleCompact = useCallback(() => {
dispatch(compactModelPickerToggled());
}, [dispatch]);
return (
<Flex flexDir="column" w="full">
<Flex flexDir="column" w="full" gap={2}>
<Flex gap={2} alignItems="center">
<Input ref={ref} {...props} placeholder={t('modelManager.filterModels')} />
<NavigateToModelManagerButton />
@@ -265,13 +323,42 @@ const SearchBarComponent = typedMemo(
onClick={onToggleCompact}
/>
</Flex>
<Flex gap={2} alignItems="center"></Flex>
<Flex gap={2} alignItems="center">
{ctx.basesWithModels.map((base) => (
<ModelBaseFilterButton key={base} base={base} />
))}
</Flex>
</Flex>
);
})
);
SearchBarComponent.displayName = 'SearchBarComponent';
const ModelBaseFilterButton = memo(({ base }: { base: BaseModelType }) => {
const { ctx } = usePickerContext<AnyModelConfig, GroupData, PickerExtraContext>();
const onClick = useCallback(() => {
ctx.toggleBaseModelTypeFilter(base);
}, [base, ctx]);
return (
<Badge
role="button"
size="xs"
variant="solid"
userSelect="none"
bg={ctx.baseModelTypeFilters[base] ? `${BASE_COLOR_MAP[base]}.300` : 'transparent'}
color={ctx.baseModelTypeFilters[base] ? undefined : 'base.200'}
borderColor={`${BASE_COLOR_MAP[base]}.300`}
borderWidth={1}
onClick={onClick}
>
{MODEL_TYPE_SHORT_MAP[base]}
</Badge>
);
});
ModelBaseFilterButton.displayName = 'ModelBaseFilterButton';
const PickerGroupComponent = memo(
({ group, children }: PropsWithChildren<{ group: Group<AnyModelConfig, GroupData> }>) => {
return (