mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-14 06:15:59 -05:00
feat(ui): model picker filter buttons
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import type { BoxProps, FormLabelProps, InputProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Badge,
|
||||
Box,
|
||||
Button,
|
||||
Expander,
|
||||
@@ -23,7 +24,13 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import type { Group, ImperativeModelPickerHandle } from 'common/components/Picker/Picker';
|
||||
import { DefaultNoMatchesFallback, DefaultNoOptionsFallback, getRegex, Picker } from 'common/components/Picker/Picker';
|
||||
import {
|
||||
DefaultNoMatchesFallback,
|
||||
DefaultNoOptionsFallback,
|
||||
getRegex,
|
||||
Picker,
|
||||
usePickerContext,
|
||||
} from 'common/components/Picker/Picker';
|
||||
import { useDisclosure } from 'common/hooks/useBoolean';
|
||||
import { fixedForwardRef } from 'common/util/fixedForwardRef';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
@@ -48,8 +55,9 @@ import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/
|
||||
import { selectActiveTab, selectCompactModelPicker } from 'features/ui/store/uiSelectors';
|
||||
import { compactModelPickerToggled } from 'features/ui/store/uiSlice';
|
||||
import { filesize } from 'filesize';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsInLineVerticalBold, PiArrowsOutLineVerticalBold, PiCaretDownBold } from 'react-icons/pi';
|
||||
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -131,9 +139,57 @@ type GroupData = {
|
||||
description: string;
|
||||
};
|
||||
|
||||
type BaseModelTypeFilters = { [key in BaseModelType]?: boolean };
|
||||
|
||||
type PickerExtraContext = {
|
||||
toggleBaseModelTypeFilter: (baseModelType: BaseModelType) => void;
|
||||
basesWithModels: BaseModelType[];
|
||||
baseModelTypeFilters: BaseModelTypeFilters;
|
||||
};
|
||||
|
||||
const MainModelPicker = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const [modelConfigs] = useMainModels();
|
||||
const basesWithModels = useMemo(() => {
|
||||
const bases: BaseModelType[] = [];
|
||||
for (const modelConfig of modelConfigs) {
|
||||
if (!bases.includes(modelConfig.base)) {
|
||||
bases.push(modelConfig.base);
|
||||
}
|
||||
}
|
||||
return bases;
|
||||
}, [modelConfigs]);
|
||||
const [baseModelTypeFilters, setBaseModelTypeFilters] = useState<BaseModelTypeFilters>({});
|
||||
useEffect(() => {
|
||||
const newFilters: BaseModelTypeFilters = {};
|
||||
if (isEqual(Object.keys(baseModelTypeFilters), basesWithModels)) {
|
||||
return;
|
||||
}
|
||||
for (const base of basesWithModels) {
|
||||
if (newFilters[base] === undefined) {
|
||||
newFilters[base] = true;
|
||||
} else {
|
||||
newFilters[base] = baseModelTypeFilters[base];
|
||||
}
|
||||
}
|
||||
setBaseModelTypeFilters(newFilters);
|
||||
}, [baseModelTypeFilters, basesWithModels]);
|
||||
const toggleBaseModelTypeFilter = useCallback(
|
||||
(baseModelType: BaseModelType) => {
|
||||
setBaseModelTypeFilters((prev) => {
|
||||
const newFilters: BaseModelTypeFilters = {};
|
||||
for (const base of basesWithModels) {
|
||||
newFilters[base] = baseModelType === base ? !prev[base] : prev[base];
|
||||
}
|
||||
return newFilters;
|
||||
});
|
||||
},
|
||||
[basesWithModels]
|
||||
);
|
||||
const ctx = useMemo(
|
||||
() => ({ toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters }),
|
||||
[toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters]
|
||||
);
|
||||
const grouped = useMemo<Group<AnyModelConfig, GroupData>[]>(() => {
|
||||
const groups: {
|
||||
[base in BaseModelType]?: Group<AnyModelConfig, GroupData>;
|
||||
@@ -141,7 +197,7 @@ const MainModelPicker = memo(() => {
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
let group = groups[modelConfig.base];
|
||||
if (!group) {
|
||||
if (!group && baseModelTypeFilters[modelConfig.base]) {
|
||||
group = {
|
||||
id: modelConfig.base,
|
||||
data: { base: modelConfig.base, description: `A brief description of ${modelConfig.base} models.` },
|
||||
@@ -149,8 +205,9 @@ const MainModelPicker = memo(() => {
|
||||
};
|
||||
groups[modelConfig.base] = group;
|
||||
}
|
||||
|
||||
group.options.push(modelConfig);
|
||||
if (group) {
|
||||
group.options.push(modelConfig);
|
||||
}
|
||||
}
|
||||
|
||||
const sortedGroups: Group<AnyModelConfig, GroupData>[] = [];
|
||||
@@ -178,7 +235,7 @@ const MainModelPicker = memo(() => {
|
||||
sortedGroups.push(...Object.values(groups));
|
||||
|
||||
return sortedGroups;
|
||||
}, [modelConfigs]);
|
||||
}, [baseModelTypeFilters, modelConfigs]);
|
||||
const modelConfig = useSelectedModelConfig();
|
||||
const popover = useDisclosure(false);
|
||||
const pickerRef = useRef<ImperativeModelPickerHandle>(null);
|
||||
@@ -222,7 +279,7 @@ const MainModelPicker = memo(() => {
|
||||
<PopoverContent p={0} w={448} h={512}>
|
||||
<PopoverArrow />
|
||||
<PopoverBody p={0} w="full" h="full">
|
||||
<Picker<AnyModelConfig, GroupData>
|
||||
<Picker<AnyModelConfig, GroupData, PickerExtraContext>
|
||||
handleRef={pickerRef}
|
||||
options={grouped}
|
||||
getOptionId={getOptionId}
|
||||
@@ -234,6 +291,7 @@ const MainModelPicker = memo(() => {
|
||||
SearchBarComponent={SearchBarComponent}
|
||||
noOptionsFallback={<DefaultNoOptionsFallback label={t('modelManager.noModelsInstalled')} />}
|
||||
noMatchesFallback={<DefaultNoMatchesFallback label={t('modelManager.noMatchingModels')} />}
|
||||
ctx={ctx}
|
||||
/>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
@@ -248,12 +306,12 @@ const SearchBarComponent = typedMemo(
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const compactModelPicker = useAppSelector(selectCompactModelPicker);
|
||||
|
||||
const { ctx } = usePickerContext<AnyModelConfig, GroupData, PickerExtraContext>();
|
||||
const onToggleCompact = useCallback(() => {
|
||||
dispatch(compactModelPickerToggled());
|
||||
}, [dispatch]);
|
||||
return (
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex flexDir="column" w="full" gap={2}>
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Input ref={ref} {...props} placeholder={t('modelManager.filterModels')} />
|
||||
<NavigateToModelManagerButton />
|
||||
@@ -265,13 +323,42 @@ const SearchBarComponent = typedMemo(
|
||||
onClick={onToggleCompact}
|
||||
/>
|
||||
</Flex>
|
||||
<Flex gap={2} alignItems="center"></Flex>
|
||||
<Flex gap={2} alignItems="center">
|
||||
{ctx.basesWithModels.map((base) => (
|
||||
<ModelBaseFilterButton key={base} base={base} />
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
})
|
||||
);
|
||||
SearchBarComponent.displayName = 'SearchBarComponent';
|
||||
|
||||
const ModelBaseFilterButton = memo(({ base }: { base: BaseModelType }) => {
|
||||
const { ctx } = usePickerContext<AnyModelConfig, GroupData, PickerExtraContext>();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
ctx.toggleBaseModelTypeFilter(base);
|
||||
}, [base, ctx]);
|
||||
|
||||
return (
|
||||
<Badge
|
||||
role="button"
|
||||
size="xs"
|
||||
variant="solid"
|
||||
userSelect="none"
|
||||
bg={ctx.baseModelTypeFilters[base] ? `${BASE_COLOR_MAP[base]}.300` : 'transparent'}
|
||||
color={ctx.baseModelTypeFilters[base] ? undefined : 'base.200'}
|
||||
borderColor={`${BASE_COLOR_MAP[base]}.300`}
|
||||
borderWidth={1}
|
||||
onClick={onClick}
|
||||
>
|
||||
{MODEL_TYPE_SHORT_MAP[base]}
|
||||
</Badge>
|
||||
);
|
||||
});
|
||||
ModelBaseFilterButton.displayName = 'ModelBaseFilterButton';
|
||||
|
||||
const PickerGroupComponent = memo(
|
||||
({ group, children }: PropsWithChildren<{ group: Group<AnyModelConfig, GroupData> }>) => {
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user