mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 06:18:03 -05:00
feat(ui): model picker filter buttons
This commit is contained in:
@@ -78,7 +78,7 @@ export const DefaultNoMatchesFallback = typedMemo(({ label }: { label?: string }
|
||||
});
|
||||
DefaultNoMatchesFallback.displayName = 'DefaultNoMatchesFallback';
|
||||
|
||||
export type PickerProps<T extends object, U> = {
|
||||
export type PickerProps<T extends object, U, C> = {
|
||||
options: (T | Group<T>)[];
|
||||
getOptionId: (option: T) => string;
|
||||
isMatch: (option: T, searchTerm: string) => boolean;
|
||||
@@ -96,9 +96,10 @@ export type PickerProps<T extends object, U> = {
|
||||
} & BoxProps
|
||||
>;
|
||||
GroupComponent?: React.ComponentType<PropsWithChildren<{ group: Group<T, U> } & BoxProps>>;
|
||||
ctx: C;
|
||||
};
|
||||
|
||||
type PickerContextState<T extends object, U> = {
|
||||
type PickerContextState<T extends object, U, C> = {
|
||||
options: (T | Group<T>)[];
|
||||
getOptionId: (option: T) => string;
|
||||
isMatch: (option: T, searchTerm: string) => boolean;
|
||||
@@ -110,11 +111,12 @@ type PickerContextState<T extends object, U> = {
|
||||
noMatchesFallback: React.ReactNode;
|
||||
OptionComponent: React.ComponentType<{ option: T } & BoxProps>;
|
||||
GroupComponent: React.ComponentType<PropsWithChildren<{ group: Group<T, U> } & BoxProps>>;
|
||||
ctx: C;
|
||||
};
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const PickerContext = createContext<PickerContextState<any, any> | null>(null);
|
||||
export const usePickerContext = <T extends object, U>(): PickerContextState<T, U> => {
|
||||
const PickerContext = createContext<PickerContextState<any, any, any> | null>(null);
|
||||
export const usePickerContext = <T extends object, U, C>(): PickerContextState<T, U, C> => {
|
||||
const context = useContext(PickerContext);
|
||||
assert(context !== null, 'usePickerContext must be used within a PickerProvider');
|
||||
return context;
|
||||
@@ -191,7 +193,7 @@ const flattenOptions = <T extends object>(options: (T | Group<T>)[]): T[] => {
|
||||
return flattened;
|
||||
};
|
||||
|
||||
export const Picker = typedMemo(<T extends object, U>(props: PickerProps<T, U>) => {
|
||||
export const Picker = typedMemo(<T extends object, U = undefined, C = undefined>(props: PickerProps<T, U, C>) => {
|
||||
const {
|
||||
getOptionId,
|
||||
options,
|
||||
@@ -206,6 +208,7 @@ export const Picker = typedMemo(<T extends object, U>(props: PickerProps<T, U>)
|
||||
noOptionsFallback = <DefaultNoOptionsFallback />,
|
||||
OptionComponent = DefaultOptionComponent,
|
||||
GroupComponent = DefaultGroupComponent,
|
||||
ctx: ctxProp,
|
||||
} = props;
|
||||
const [activeOptionId, setActiveOptionId, getActiveOptionId] = useStateImperative(() =>
|
||||
getFirstOptionId(options, getOptionId)
|
||||
@@ -372,7 +375,8 @@ export const Picker = typedMemo(<T extends object, U>(props: PickerProps<T, U>)
|
||||
noMatchesFallback,
|
||||
OptionComponent,
|
||||
GroupComponent,
|
||||
}) satisfies PickerContextState<T, U>,
|
||||
ctx: ctxProp,
|
||||
}) satisfies PickerContextState<T, U, C>,
|
||||
[
|
||||
options,
|
||||
getOptionId,
|
||||
@@ -385,6 +389,7 @@ export const Picker = typedMemo(<T extends object, U>(props: PickerProps<T, U>)
|
||||
noMatchesFallback,
|
||||
OptionComponent,
|
||||
GroupComponent,
|
||||
ctxProp,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -429,7 +434,7 @@ const DefaultPickerSearchBarComponent = typedMemo(
|
||||
DefaultPickerSearchBarComponent.displayName = 'DefaultPickerSearchBarComponent';
|
||||
|
||||
const PickerList = typedMemo(
|
||||
<T extends object, U>({
|
||||
<T extends object, U, C>({
|
||||
items,
|
||||
activeOptionId,
|
||||
selectedItemId,
|
||||
@@ -438,7 +443,7 @@ const PickerList = typedMemo(
|
||||
activeOptionId: string | undefined;
|
||||
selectedItemId: string | undefined;
|
||||
}) => {
|
||||
const { getOptionId, getIsDisabled } = usePickerContext<T, U>();
|
||||
const { getOptionId, getIsDisabled } = usePickerContext<T, U, C>();
|
||||
|
||||
if (items.length === 0) {
|
||||
return (
|
||||
@@ -486,7 +491,7 @@ const PickerList = typedMemo(
|
||||
PickerList.displayName = 'PickerList';
|
||||
|
||||
const PickerOptionGroup = typedMemo(
|
||||
<T extends object, U>({
|
||||
<T extends object, U, C>({
|
||||
group,
|
||||
activeOptionId,
|
||||
selectedItemId,
|
||||
@@ -495,7 +500,7 @@ const PickerOptionGroup = typedMemo(
|
||||
activeOptionId: string | undefined;
|
||||
selectedItemId: string | undefined;
|
||||
}) => {
|
||||
const { getOptionId, GroupComponent, getIsDisabled } = usePickerContext<T, U>();
|
||||
const { getOptionId, GroupComponent, getIsDisabled } = usePickerContext<T, U, C>();
|
||||
|
||||
return (
|
||||
<GroupComponent group={group}>
|
||||
@@ -519,14 +524,14 @@ const PickerOptionGroup = typedMemo(
|
||||
PickerOptionGroup.displayName = 'PickerOptionGroup';
|
||||
|
||||
const PickerOption = typedMemo(
|
||||
<T extends object, U>(props: {
|
||||
<T extends object, U, C>(props: {
|
||||
id: string;
|
||||
option: T;
|
||||
isActive: boolean;
|
||||
isSelected: boolean;
|
||||
isDisabled: boolean;
|
||||
}) => {
|
||||
const { OptionComponent, setActiveOptionId, onSelectById } = usePickerContext<T, U>();
|
||||
const { OptionComponent, setActiveOptionId, onSelectById } = usePickerContext<T, U, C>();
|
||||
const { id, option, isActive, isDisabled, isSelected } = props;
|
||||
const onPointerMove = useCallback(() => {
|
||||
setActiveOptionId(id);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { BoxProps, FormLabelProps, InputProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Badge,
|
||||
Box,
|
||||
Button,
|
||||
Expander,
|
||||
@@ -23,7 +24,13 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import type { Group, ImperativeModelPickerHandle } from 'common/components/Picker/Picker';
|
||||
import { DefaultNoMatchesFallback, DefaultNoOptionsFallback, getRegex, Picker } from 'common/components/Picker/Picker';
|
||||
import {
|
||||
DefaultNoMatchesFallback,
|
||||
DefaultNoOptionsFallback,
|
||||
getRegex,
|
||||
Picker,
|
||||
usePickerContext,
|
||||
} from 'common/components/Picker/Picker';
|
||||
import { useDisclosure } from 'common/hooks/useBoolean';
|
||||
import { fixedForwardRef } from 'common/util/fixedForwardRef';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
@@ -48,8 +55,9 @@ import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/
|
||||
import { selectActiveTab, selectCompactModelPicker } from 'features/ui/store/uiSelectors';
|
||||
import { compactModelPickerToggled } from 'features/ui/store/uiSlice';
|
||||
import { filesize } from 'filesize';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsInLineVerticalBold, PiArrowsOutLineVerticalBold, PiCaretDownBold } from 'react-icons/pi';
|
||||
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -131,9 +139,57 @@ type GroupData = {
|
||||
description: string;
|
||||
};
|
||||
|
||||
type BaseModelTypeFilters = { [key in BaseModelType]?: boolean };
|
||||
|
||||
type PickerExtraContext = {
|
||||
toggleBaseModelTypeFilter: (baseModelType: BaseModelType) => void;
|
||||
basesWithModels: BaseModelType[];
|
||||
baseModelTypeFilters: BaseModelTypeFilters;
|
||||
};
|
||||
|
||||
const MainModelPicker = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const [modelConfigs] = useMainModels();
|
||||
const basesWithModels = useMemo(() => {
|
||||
const bases: BaseModelType[] = [];
|
||||
for (const modelConfig of modelConfigs) {
|
||||
if (!bases.includes(modelConfig.base)) {
|
||||
bases.push(modelConfig.base);
|
||||
}
|
||||
}
|
||||
return bases;
|
||||
}, [modelConfigs]);
|
||||
const [baseModelTypeFilters, setBaseModelTypeFilters] = useState<BaseModelTypeFilters>({});
|
||||
useEffect(() => {
|
||||
const newFilters: BaseModelTypeFilters = {};
|
||||
if (isEqual(Object.keys(baseModelTypeFilters), basesWithModels)) {
|
||||
return;
|
||||
}
|
||||
for (const base of basesWithModels) {
|
||||
if (newFilters[base] === undefined) {
|
||||
newFilters[base] = true;
|
||||
} else {
|
||||
newFilters[base] = baseModelTypeFilters[base];
|
||||
}
|
||||
}
|
||||
setBaseModelTypeFilters(newFilters);
|
||||
}, [baseModelTypeFilters, basesWithModels]);
|
||||
const toggleBaseModelTypeFilter = useCallback(
|
||||
(baseModelType: BaseModelType) => {
|
||||
setBaseModelTypeFilters((prev) => {
|
||||
const newFilters: BaseModelTypeFilters = {};
|
||||
for (const base of basesWithModels) {
|
||||
newFilters[base] = baseModelType === base ? !prev[base] : prev[base];
|
||||
}
|
||||
return newFilters;
|
||||
});
|
||||
},
|
||||
[basesWithModels]
|
||||
);
|
||||
const ctx = useMemo(
|
||||
() => ({ toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters }),
|
||||
[toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters]
|
||||
);
|
||||
const grouped = useMemo<Group<AnyModelConfig, GroupData>[]>(() => {
|
||||
const groups: {
|
||||
[base in BaseModelType]?: Group<AnyModelConfig, GroupData>;
|
||||
@@ -141,7 +197,7 @@ const MainModelPicker = memo(() => {
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
let group = groups[modelConfig.base];
|
||||
if (!group) {
|
||||
if (!group && baseModelTypeFilters[modelConfig.base]) {
|
||||
group = {
|
||||
id: modelConfig.base,
|
||||
data: { base: modelConfig.base, description: `A brief description of ${modelConfig.base} models.` },
|
||||
@@ -149,8 +205,9 @@ const MainModelPicker = memo(() => {
|
||||
};
|
||||
groups[modelConfig.base] = group;
|
||||
}
|
||||
|
||||
group.options.push(modelConfig);
|
||||
if (group) {
|
||||
group.options.push(modelConfig);
|
||||
}
|
||||
}
|
||||
|
||||
const sortedGroups: Group<AnyModelConfig, GroupData>[] = [];
|
||||
@@ -178,7 +235,7 @@ const MainModelPicker = memo(() => {
|
||||
sortedGroups.push(...Object.values(groups));
|
||||
|
||||
return sortedGroups;
|
||||
}, [modelConfigs]);
|
||||
}, [baseModelTypeFilters, modelConfigs]);
|
||||
const modelConfig = useSelectedModelConfig();
|
||||
const popover = useDisclosure(false);
|
||||
const pickerRef = useRef<ImperativeModelPickerHandle>(null);
|
||||
@@ -222,7 +279,7 @@ const MainModelPicker = memo(() => {
|
||||
<PopoverContent p={0} w={448} h={512}>
|
||||
<PopoverArrow />
|
||||
<PopoverBody p={0} w="full" h="full">
|
||||
<Picker<AnyModelConfig, GroupData>
|
||||
<Picker<AnyModelConfig, GroupData, PickerExtraContext>
|
||||
handleRef={pickerRef}
|
||||
options={grouped}
|
||||
getOptionId={getOptionId}
|
||||
@@ -234,6 +291,7 @@ const MainModelPicker = memo(() => {
|
||||
SearchBarComponent={SearchBarComponent}
|
||||
noOptionsFallback={<DefaultNoOptionsFallback label={t('modelManager.noModelsInstalled')} />}
|
||||
noMatchesFallback={<DefaultNoMatchesFallback label={t('modelManager.noMatchingModels')} />}
|
||||
ctx={ctx}
|
||||
/>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
@@ -248,12 +306,12 @@ const SearchBarComponent = typedMemo(
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const compactModelPicker = useAppSelector(selectCompactModelPicker);
|
||||
|
||||
const { ctx } = usePickerContext<AnyModelConfig, GroupData, PickerExtraContext>();
|
||||
const onToggleCompact = useCallback(() => {
|
||||
dispatch(compactModelPickerToggled());
|
||||
}, [dispatch]);
|
||||
return (
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex flexDir="column" w="full" gap={2}>
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Input ref={ref} {...props} placeholder={t('modelManager.filterModels')} />
|
||||
<NavigateToModelManagerButton />
|
||||
@@ -265,13 +323,42 @@ const SearchBarComponent = typedMemo(
|
||||
onClick={onToggleCompact}
|
||||
/>
|
||||
</Flex>
|
||||
<Flex gap={2} alignItems="center"></Flex>
|
||||
<Flex gap={2} alignItems="center">
|
||||
{ctx.basesWithModels.map((base) => (
|
||||
<ModelBaseFilterButton key={base} base={base} />
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
})
|
||||
);
|
||||
SearchBarComponent.displayName = 'SearchBarComponent';
|
||||
|
||||
const ModelBaseFilterButton = memo(({ base }: { base: BaseModelType }) => {
|
||||
const { ctx } = usePickerContext<AnyModelConfig, GroupData, PickerExtraContext>();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
ctx.toggleBaseModelTypeFilter(base);
|
||||
}, [base, ctx]);
|
||||
|
||||
return (
|
||||
<Badge
|
||||
role="button"
|
||||
size="xs"
|
||||
variant="solid"
|
||||
userSelect="none"
|
||||
bg={ctx.baseModelTypeFilters[base] ? `${BASE_COLOR_MAP[base]}.300` : 'transparent'}
|
||||
color={ctx.baseModelTypeFilters[base] ? undefined : 'base.200'}
|
||||
borderColor={`${BASE_COLOR_MAP[base]}.300`}
|
||||
borderWidth={1}
|
||||
onClick={onClick}
|
||||
>
|
||||
{MODEL_TYPE_SHORT_MAP[base]}
|
||||
</Badge>
|
||||
);
|
||||
});
|
||||
ModelBaseFilterButton.displayName = 'ModelBaseFilterButton';
|
||||
|
||||
const PickerGroupComponent = memo(
|
||||
({ group, children }: PropsWithChildren<{ group: Group<AnyModelConfig, GroupData> }>) => {
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user