feat(ui): model picker filter buttons

This commit is contained in:
psychedelicious
2025-04-17 22:21:16 +10:00
parent aeb3841a6f
commit 97d45ceaf2
2 changed files with 114 additions and 22 deletions

View File

@@ -78,7 +78,7 @@ export const DefaultNoMatchesFallback = typedMemo(({ label }: { label?: string }
});
DefaultNoMatchesFallback.displayName = 'DefaultNoMatchesFallback';
export type PickerProps<T extends object, U> = {
export type PickerProps<T extends object, U, C> = {
options: (T | Group<T>)[];
getOptionId: (option: T) => string;
isMatch: (option: T, searchTerm: string) => boolean;
@@ -96,9 +96,10 @@ export type PickerProps<T extends object, U> = {
} & BoxProps
>;
GroupComponent?: React.ComponentType<PropsWithChildren<{ group: Group<T, U> } & BoxProps>>;
ctx: C;
};
type PickerContextState<T extends object, U> = {
type PickerContextState<T extends object, U, C> = {
options: (T | Group<T>)[];
getOptionId: (option: T) => string;
isMatch: (option: T, searchTerm: string) => boolean;
@@ -110,11 +111,12 @@ type PickerContextState<T extends object, U> = {
noMatchesFallback: React.ReactNode;
OptionComponent: React.ComponentType<{ option: T } & BoxProps>;
GroupComponent: React.ComponentType<PropsWithChildren<{ group: Group<T, U> } & BoxProps>>;
ctx: C;
};
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const PickerContext = createContext<PickerContextState<any, any> | null>(null);
export const usePickerContext = <T extends object, U>(): PickerContextState<T, U> => {
const PickerContext = createContext<PickerContextState<any, any, any> | null>(null);
export const usePickerContext = <T extends object, U, C>(): PickerContextState<T, U, C> => {
const context = useContext(PickerContext);
assert(context !== null, 'usePickerContext must be used within a PickerProvider');
return context;
@@ -191,7 +193,7 @@ const flattenOptions = <T extends object>(options: (T | Group<T>)[]): T[] => {
return flattened;
};
export const Picker = typedMemo(<T extends object, U>(props: PickerProps<T, U>) => {
export const Picker = typedMemo(<T extends object, U = undefined, C = undefined>(props: PickerProps<T, U, C>) => {
const {
getOptionId,
options,
@@ -206,6 +208,7 @@ export const Picker = typedMemo(<T extends object, U>(props: PickerProps<T, U>)
noOptionsFallback = <DefaultNoOptionsFallback />,
OptionComponent = DefaultOptionComponent,
GroupComponent = DefaultGroupComponent,
ctx: ctxProp,
} = props;
const [activeOptionId, setActiveOptionId, getActiveOptionId] = useStateImperative(() =>
getFirstOptionId(options, getOptionId)
@@ -372,7 +375,8 @@ export const Picker = typedMemo(<T extends object, U>(props: PickerProps<T, U>)
noMatchesFallback,
OptionComponent,
GroupComponent,
}) satisfies PickerContextState<T, U>,
ctx: ctxProp,
}) satisfies PickerContextState<T, U, C>,
[
options,
getOptionId,
@@ -385,6 +389,7 @@ export const Picker = typedMemo(<T extends object, U>(props: PickerProps<T, U>)
noMatchesFallback,
OptionComponent,
GroupComponent,
ctxProp,
]
);
@@ -429,7 +434,7 @@ const DefaultPickerSearchBarComponent = typedMemo(
DefaultPickerSearchBarComponent.displayName = 'DefaultPickerSearchBarComponent';
const PickerList = typedMemo(
<T extends object, U>({
<T extends object, U, C>({
items,
activeOptionId,
selectedItemId,
@@ -438,7 +443,7 @@ const PickerList = typedMemo(
activeOptionId: string | undefined;
selectedItemId: string | undefined;
}) => {
const { getOptionId, getIsDisabled } = usePickerContext<T, U>();
const { getOptionId, getIsDisabled } = usePickerContext<T, U, C>();
if (items.length === 0) {
return (
@@ -486,7 +491,7 @@ const PickerList = typedMemo(
PickerList.displayName = 'PickerList';
const PickerOptionGroup = typedMemo(
<T extends object, U>({
<T extends object, U, C>({
group,
activeOptionId,
selectedItemId,
@@ -495,7 +500,7 @@ const PickerOptionGroup = typedMemo(
activeOptionId: string | undefined;
selectedItemId: string | undefined;
}) => {
const { getOptionId, GroupComponent, getIsDisabled } = usePickerContext<T, U>();
const { getOptionId, GroupComponent, getIsDisabled } = usePickerContext<T, U, C>();
return (
<GroupComponent group={group}>
@@ -519,14 +524,14 @@ const PickerOptionGroup = typedMemo(
PickerOptionGroup.displayName = 'PickerOptionGroup';
const PickerOption = typedMemo(
<T extends object, U>(props: {
<T extends object, U, C>(props: {
id: string;
option: T;
isActive: boolean;
isSelected: boolean;
isDisabled: boolean;
}) => {
const { OptionComponent, setActiveOptionId, onSelectById } = usePickerContext<T, U>();
const { OptionComponent, setActiveOptionId, onSelectById } = usePickerContext<T, U, C>();
const { id, option, isActive, isDisabled, isSelected } = props;
const onPointerMove = useCallback(() => {
setActiveOptionId(id);

View File

@@ -1,5 +1,6 @@
import type { BoxProps, FormLabelProps, InputProps, SystemStyleObject } from '@invoke-ai/ui-library';
import {
Badge,
Box,
Button,
Expander,
@@ -23,7 +24,13 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type { Group, ImperativeModelPickerHandle } from 'common/components/Picker/Picker';
import { DefaultNoMatchesFallback, DefaultNoOptionsFallback, getRegex, Picker } from 'common/components/Picker/Picker';
import {
DefaultNoMatchesFallback,
DefaultNoOptionsFallback,
getRegex,
Picker,
usePickerContext,
} from 'common/components/Picker/Picker';
import { useDisclosure } from 'common/hooks/useBoolean';
import { fixedForwardRef } from 'common/util/fixedForwardRef';
import { typedMemo } from 'common/util/typedMemo';
@@ -48,8 +55,9 @@ import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/
import { selectActiveTab, selectCompactModelPicker } from 'features/ui/store/uiSelectors';
import { compactModelPickerToggled } from 'features/ui/store/uiSlice';
import { filesize } from 'filesize';
import { isEqual } from 'lodash-es';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsInLineVerticalBold, PiArrowsOutLineVerticalBold, PiCaretDownBold } from 'react-icons/pi';
import { useMainModels } from 'services/api/hooks/modelsByType';
@@ -131,9 +139,57 @@ type GroupData = {
description: string;
};
type BaseModelTypeFilters = { [key in BaseModelType]?: boolean };
type PickerExtraContext = {
toggleBaseModelTypeFilter: (baseModelType: BaseModelType) => void;
basesWithModels: BaseModelType[];
baseModelTypeFilters: BaseModelTypeFilters;
};
const MainModelPicker = memo(() => {
const { t } = useTranslation();
const [modelConfigs] = useMainModels();
const basesWithModels = useMemo(() => {
const bases: BaseModelType[] = [];
for (const modelConfig of modelConfigs) {
if (!bases.includes(modelConfig.base)) {
bases.push(modelConfig.base);
}
}
return bases;
}, [modelConfigs]);
const [baseModelTypeFilters, setBaseModelTypeFilters] = useState<BaseModelTypeFilters>({});
useEffect(() => {
const newFilters: BaseModelTypeFilters = {};
if (isEqual(Object.keys(baseModelTypeFilters), basesWithModels)) {
return;
}
for (const base of basesWithModels) {
if (newFilters[base] === undefined) {
newFilters[base] = true;
} else {
newFilters[base] = baseModelTypeFilters[base];
}
}
setBaseModelTypeFilters(newFilters);
}, [baseModelTypeFilters, basesWithModels]);
const toggleBaseModelTypeFilter = useCallback(
(baseModelType: BaseModelType) => {
setBaseModelTypeFilters((prev) => {
const newFilters: BaseModelTypeFilters = {};
for (const base of basesWithModels) {
newFilters[base] = baseModelType === base ? !prev[base] : prev[base];
}
return newFilters;
});
},
[basesWithModels]
);
const ctx = useMemo(
() => ({ toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters }),
[toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters]
);
const grouped = useMemo<Group<AnyModelConfig, GroupData>[]>(() => {
const groups: {
[base in BaseModelType]?: Group<AnyModelConfig, GroupData>;
@@ -141,7 +197,7 @@ const MainModelPicker = memo(() => {
for (const modelConfig of modelConfigs) {
let group = groups[modelConfig.base];
if (!group) {
if (!group && baseModelTypeFilters[modelConfig.base]) {
group = {
id: modelConfig.base,
data: { base: modelConfig.base, description: `A brief description of ${modelConfig.base} models.` },
@@ -149,8 +205,9 @@ const MainModelPicker = memo(() => {
};
groups[modelConfig.base] = group;
}
group.options.push(modelConfig);
if (group) {
group.options.push(modelConfig);
}
}
const sortedGroups: Group<AnyModelConfig, GroupData>[] = [];
@@ -178,7 +235,7 @@ const MainModelPicker = memo(() => {
sortedGroups.push(...Object.values(groups));
return sortedGroups;
}, [modelConfigs]);
}, [baseModelTypeFilters, modelConfigs]);
const modelConfig = useSelectedModelConfig();
const popover = useDisclosure(false);
const pickerRef = useRef<ImperativeModelPickerHandle>(null);
@@ -222,7 +279,7 @@ const MainModelPicker = memo(() => {
<PopoverContent p={0} w={448} h={512}>
<PopoverArrow />
<PopoverBody p={0} w="full" h="full">
<Picker<AnyModelConfig, GroupData>
<Picker<AnyModelConfig, GroupData, PickerExtraContext>
handleRef={pickerRef}
options={grouped}
getOptionId={getOptionId}
@@ -234,6 +291,7 @@ const MainModelPicker = memo(() => {
SearchBarComponent={SearchBarComponent}
noOptionsFallback={<DefaultNoOptionsFallback label={t('modelManager.noModelsInstalled')} />}
noMatchesFallback={<DefaultNoMatchesFallback label={t('modelManager.noMatchingModels')} />}
ctx={ctx}
/>
</PopoverBody>
</PopoverContent>
@@ -248,12 +306,12 @@ const SearchBarComponent = typedMemo(
const { t } = useTranslation();
const dispatch = useAppDispatch();
const compactModelPicker = useAppSelector(selectCompactModelPicker);
const { ctx } = usePickerContext<AnyModelConfig, GroupData, PickerExtraContext>();
const onToggleCompact = useCallback(() => {
dispatch(compactModelPickerToggled());
}, [dispatch]);
return (
<Flex flexDir="column" w="full">
<Flex flexDir="column" w="full" gap={2}>
<Flex gap={2} alignItems="center">
<Input ref={ref} {...props} placeholder={t('modelManager.filterModels')} />
<NavigateToModelManagerButton />
@@ -265,13 +323,42 @@ const SearchBarComponent = typedMemo(
onClick={onToggleCompact}
/>
</Flex>
<Flex gap={2} alignItems="center"></Flex>
<Flex gap={2} alignItems="center">
{ctx.basesWithModels.map((base) => (
<ModelBaseFilterButton key={base} base={base} />
))}
</Flex>
</Flex>
);
})
);
SearchBarComponent.displayName = 'SearchBarComponent';
const ModelBaseFilterButton = memo(({ base }: { base: BaseModelType }) => {
const { ctx } = usePickerContext<AnyModelConfig, GroupData, PickerExtraContext>();
const onClick = useCallback(() => {
ctx.toggleBaseModelTypeFilter(base);
}, [base, ctx]);
return (
<Badge
role="button"
size="xs"
variant="solid"
userSelect="none"
bg={ctx.baseModelTypeFilters[base] ? `${BASE_COLOR_MAP[base]}.300` : 'transparent'}
color={ctx.baseModelTypeFilters[base] ? undefined : 'base.200'}
borderColor={`${BASE_COLOR_MAP[base]}.300`}
borderWidth={1}
onClick={onClick}
>
{MODEL_TYPE_SHORT_MAP[base]}
</Badge>
);
});
ModelBaseFilterButton.displayName = 'ModelBaseFilterButton';
const PickerGroupComponent = memo(
({ group, children }: PropsWithChildren<{ group: Group<AnyModelConfig, GroupData> }>) => {
return (