feat(ui): iterate on model combobox (wip)

This commit is contained in:
psychedelicious
2025-04-15 19:46:19 +10:00
parent 87aeb7f889
commit aa7c5c281a
2 changed files with 103 additions and 104 deletions

View File

@@ -8,17 +8,16 @@ import { NavigateToModelManagerButton } from 'features/parameters/components/Mai
import type { ChangeEvent } from 'react';
import { useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig, BaseModelType } from 'services/api/types';
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export type OptionGroup<T extends object, U = any> = {
export type Group<T extends object, U = any> = {
id: string;
data: U;
options: T[];
};
const isGroup = <T extends object>(item: T | OptionGroup<T>): item is OptionGroup<T> => {
return item ? 'options' in item && Array.isArray(item.options) : false;
const isGroup = <T extends object>(option: T | Group<T>): option is Group<T> => {
return option ? 'options' in option && Array.isArray(option.options) : false;
};
export type ImperativeModelPickerHandle = {
@@ -28,7 +27,7 @@ export type ImperativeModelPickerHandle = {
setSearchTerm: (searchTerm: string) => void;
};
const DefaultItemComponent = ({ id }: { id: string }) => {
const DefaultOptionComponent = ({ id }: { id: string }) => {
return <Text fontWeight="bold">{id}</Text>;
};
@@ -37,21 +36,21 @@ const DefaultGroupHeaderComponent = ({ id }: { id: string }) => {
};
export type PickerProps<T extends object> = {
options: (T | OptionGroup<T>)[];
getId: (item: T) => string;
isMatch: (item: T, searchTerm: string) => boolean;
getIsDisabled?: (item: T) => boolean;
options: (T | Group<T>)[];
getOptionId: (option: T) => string;
isMatch: (option: T, searchTerm: string) => boolean;
getIsDisabled?: (option: T) => boolean;
selectedItem?: T;
onSelect?: (item: T) => void;
onSelect?: (option: T) => void;
onClose?: () => void;
noOptionsFallback?: React.ReactNode;
noMatchesFallback?: React.ReactNode;
handleRef?: React.Ref<ImperativeModelPickerHandle>;
ItemComponent?: React.ComponentType<{ item: T }>;
GroupHeaderComponent?: React.ComponentType<{ group: OptionGroup<T> }>;
OptionComponent?: React.ComponentType<{ option: T }>;
GroupHeaderComponent?: React.ComponentType<{ group: Group<T> }>;
};
const getRegex = (searchTerm: string) =>
export const getRegex = (searchTerm: string) =>
new RegExp(
searchTerm
.trim()
@@ -61,33 +60,7 @@ const getRegex = (searchTerm: string) =>
'gi'
);
const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
'sd-1': ['sd1', 'sd1.4', 'sd1.5', 'sd-1'],
'sd-2': ['sd2', 'sd2.0', 'sd2.1', 'sd-2'],
'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'],
};
export const isMatch = (model: AnyModelConfig, searchTerm: string) => {
const regex = getRegex(searchTerm);
if (
model.name.toLowerCase().includes(searchTerm) ||
regex.test(model.name) ||
(BASE_KEYWORDS[model.base] ?? [model.base]).some((kw) => kw.toLowerCase().includes(searchTerm) || regex.test(kw)) ||
model.type.toLowerCase().includes(searchTerm) ||
regex.test(model.type) ||
(model.description ?? '').toLowerCase().includes(searchTerm) ||
regex.test(model.description ?? '') ||
model.format.toLowerCase().includes(searchTerm) ||
regex.test(model.format)
) {
return true;
}
return false;
};
const getFirstOption = <T extends object>(options: (T | OptionGroup<T>)[]): T | undefined => {
const getFirstOption = <T extends object>(options: (T | Group<T>)[]): T | undefined => {
const firstOptionOrGroup = options[0];
if (!firstOptionOrGroup) {
return;
@@ -100,37 +73,37 @@ const getFirstOption = <T extends object>(options: (T | OptionGroup<T>)[]): T |
};
const getFirstOptionId = <T extends object>(
options: (T | OptionGroup<T>)[],
getId: (item: T) => string
options: (T | Group<T>)[],
getOptionId: (item: T) => string
): string | undefined => {
const firstOptionOrGroup = getFirstOption(options);
if (firstOptionOrGroup) {
return getId(firstOptionOrGroup);
return getOptionId(firstOptionOrGroup);
} else {
return undefined;
}
};
const findOption = <T extends object>(
options: (T | OptionGroup<T>)[],
options: (T | Group<T>)[],
id: string,
getId: (item: T) => string
getOptionId: (item: T) => string
): T | undefined => {
for (const optionOrGroup of options) {
if (isGroup(optionOrGroup)) {
const option = optionOrGroup.options.find((opt) => getId(opt) === id);
const option = optionOrGroup.options.find((opt) => getOptionId(opt) === id);
if (option) {
return option;
}
} else {
if (getId(optionOrGroup) === id) {
if (getOptionId(optionOrGroup) === id) {
return optionOrGroup;
}
}
}
};
const flattenOptions = <T extends object>(options: (T | OptionGroup<T>)[]): T[] => {
const flattenOptions = <T extends object>(options: (T | Group<T>)[]): T[] => {
const flattened: T[] = [];
for (const optionOrGroup of options) {
if (isGroup(optionOrGroup)) {
@@ -145,7 +118,7 @@ const flattenOptions = <T extends object>(options: (T | OptionGroup<T>)[]): T[]
export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
const { t } = useTranslation();
const {
getId,
getOptionId,
options,
handleRef,
isMatch,
@@ -155,15 +128,15 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
onClose,
onSelect,
selectedItem,
ItemComponent,
OptionComponent,
GroupHeaderComponent,
} = props;
const [activeOptionId, setActiveOptionId, getActiveOptionId] = useStateImperative(() =>
getFirstOptionId(options, getId)
getFirstOptionId(options, getOptionId)
);
const rootRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
const [filteredOptions, setFilteredOptions] = useState<(T | OptionGroup<T>)[]>(options);
const [filteredOptions, setFilteredOptions] = useState<(T | Group<T>)[]>(options);
const flattenedOptions = useMemo(() => flattenOptions(options), [options]);
const flattenedFilteredOptions = useMemo(() => flattenOptions(filteredOptions), [filteredOptions]);
const [searchTerm, setSearchTerm] = useState('');
@@ -176,10 +149,10 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
useEffect(() => {
if (!searchTerm) {
setFilteredOptions(options);
setActiveOptionId(getFirstOptionId(options, getId));
setActiveOptionId(getFirstOptionId(options, getOptionId));
} else {
const lowercasedSearchTerm = searchTerm.toLowerCase();
const filtered: (T | OptionGroup<T>)[] = [];
const filtered: (T | Group<T>)[] = [];
for (const item of props.options) {
if (isGroup(item)) {
const filteredItems = item.options.filter(
@@ -195,20 +168,20 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
}
}
setFilteredOptions(filtered);
setActiveOptionId(getFirstOptionId(filtered, getId));
setActiveOptionId(getFirstOptionId(filtered, getOptionId));
}
}, [searchTerm, setActiveOptionId, props.options, options, getId, isMatch, getIsDisabled]);
}, [searchTerm, setActiveOptionId, props.options, options, getOptionId, isMatch, getIsDisabled]);
const onSelectInternal = useCallback(
(id: string) => {
const item = findOption(options, id, getId);
const item = findOption(options, id, getOptionId);
if (!item) {
// Model not found? We should never get here.
return;
}
onSelect?.(item);
},
[getId, options, onSelect]
[getOptionId, options, onSelect]
);
const setValueAndScrollIntoView = useCallback(
@@ -237,11 +210,11 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
if (e.metaKey) {
const item = flattenedFilteredOptions.at(0);
if (item) {
setValueAndScrollIntoView(getId(item));
setValueAndScrollIntoView(getOptionId(item));
}
return;
}
const currentIndex = flattenedFilteredOptions.findIndex((item) => getId(item) === activeOptionId);
const currentIndex = flattenedFilteredOptions.findIndex((item) => getOptionId(item) === activeOptionId);
if (currentIndex < 0) {
return;
}
@@ -251,10 +224,10 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
}
const item = flattenedFilteredOptions.at(newIndex);
if (item) {
setValueAndScrollIntoView(getId(item));
setValueAndScrollIntoView(getOptionId(item));
}
},
[getActiveOptionId, flattenedFilteredOptions, setValueAndScrollIntoView, getId]
[getActiveOptionId, flattenedFilteredOptions, setValueAndScrollIntoView, getOptionId]
);
const next = useCallback(
@@ -267,12 +240,12 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
if (e.metaKey) {
const item = flattenedFilteredOptions.at(-1);
if (item) {
setValueAndScrollIntoView(getId(item));
setValueAndScrollIntoView(getOptionId(item));
}
return;
}
const currentIndex = flattenedFilteredOptions.findIndex((item) => getId(item) === activeOptionId);
const currentIndex = flattenedFilteredOptions.findIndex((item) => getOptionId(item) === activeOptionId);
if (currentIndex < 0) {
return;
}
@@ -282,10 +255,10 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
}
const item = flattenedFilteredOptions.at(newIndex);
if (item) {
setValueAndScrollIntoView(getId(item));
setValueAndScrollIntoView(getOptionId(item));
}
},
[getActiveOptionId, flattenedFilteredOptions, setValueAndScrollIntoView, getId]
[getActiveOptionId, flattenedFilteredOptions, setValueAndScrollIntoView, getOptionId]
);
const onKeyDown = useCallback(
@@ -296,7 +269,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
next(e);
} else if (e.key === 'Enter') {
const activeOptionId = getActiveOptionId();
const item = flattenedFilteredOptions.find((item) => getId(item) === activeOptionId);
const item = flattenedFilteredOptions.find((item) => getOptionId(item) === activeOptionId);
if (!item) {
// Model not found? We should never get here.
return;
@@ -310,7 +283,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
inputRef.current?.select();
}
},
[prev, next, getActiveOptionId, flattenedFilteredOptions, onSelect, getId, onClose]
[prev, next, getActiveOptionId, flattenedFilteredOptions, onSelect, getOptionId, onClose]
);
return (
@@ -337,13 +310,13 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
<ScrollableContent>
<PickerList
items={filteredOptions}
getId={getId}
getOptionId={getOptionId}
activeOptionId={activeOptionId}
setActiveOptionId={setActiveOptionId}
selectedItemId={selectedItem ? getId(selectedItem) : undefined}
selectedItemId={selectedItem ? getOptionId(selectedItem) : undefined}
onSelect={onSelectInternal}
getIsDisabled={getIsDisabled}
ItemComponent={ItemComponent}
OptionComponent={OptionComponent}
GroupHeaderComponent={GroupHeaderComponent}
/>
</ScrollableContent>
@@ -361,20 +334,20 @@ const PickerList = typedMemo(
setActiveOptionId,
selectedItemId,
onSelect,
getId,
getOptionId,
getIsDisabled,
ItemComponent,
OptionComponent,
GroupHeaderComponent,
}: {
items: (T | OptionGroup<T>)[];
items: (T | Group<T>)[];
activeOptionId: string | undefined;
setActiveOptionId: (key: string) => void;
selectedItemId: string | undefined;
onSelect: (key: string) => void;
getId: (item: T) => string;
getIsDisabled?: (item: T) => boolean;
ItemComponent?: React.ComponentType<{ item: T }>;
GroupHeaderComponent?: React.ComponentType<{ group: OptionGroup<T> }>;
getOptionId: (option: T) => string;
getIsDisabled?: (option: T) => boolean;
OptionComponent?: React.ComponentType<{ option: T }>;
GroupHeaderComponent?: React.ComponentType<{ group: Group<T> }>;
}) => {
if (items.length === 0) {
return (
@@ -399,27 +372,27 @@ const PickerList = typedMemo(
group={itemOrGroup}
setActiveOptionId={setActiveOptionId}
activeOptionId={activeOptionId}
getId={getId}
getOptionId={getOptionId}
onSelect={onSelect}
selectedItemId={selectedItemId}
ItemComponent={ItemComponent}
OptionComponent={OptionComponent}
getIsDisabled={getIsDisabled}
GroupHeaderComponent={GroupHeaderComponent}
/>
);
} else {
const id = getId(itemOrGroup);
const id = getOptionId(itemOrGroup);
return (
<PickerOption
key={id}
id={id}
item={itemOrGroup}
option={itemOrGroup}
setActiveOptionId={setActiveOptionId}
onSelect={onSelect}
isActive={id === activeOptionId}
isSelected={id === selectedItemId}
isDisabled={getIsDisabled?.(itemOrGroup) ?? false}
ItemComponent={ItemComponent}
OptionComponent={OptionComponent}
/>
);
}
@@ -433,42 +406,42 @@ PickerList.displayName = 'PickerList';
const PickerOptionGroup = typedMemo(
<T extends object>({
group,
getId,
getOptionId,
setActiveOptionId,
onSelect,
activeOptionId,
selectedItemId,
getIsDisabled,
ItemComponent,
OptionComponent,
GroupHeaderComponent,
}: {
group: OptionGroup<T>;
getId: (item: T) => string;
group: Group<T>;
getOptionId: (option: T) => string;
setActiveOptionId: (key: string) => void;
onSelect: (key: string) => void;
activeOptionId: string | undefined;
selectedItemId: string | undefined;
getIsDisabled?: (item: T) => boolean;
ItemComponent?: React.ComponentType<{ item: T }>;
GroupHeaderComponent?: React.ComponentType<{ group: OptionGroup<T> }>;
getIsDisabled?: (option: T) => boolean;
OptionComponent?: React.ComponentType<{ option: T }>;
GroupHeaderComponent?: React.ComponentType<{ group: Group<T> }>;
}) => {
return (
<Flex key={group.id} flexDir="column" gap={2} w="full">
{GroupHeaderComponent ? <GroupHeaderComponent group={group} /> : <DefaultGroupHeaderComponent id={group.id} />}
<Flex flexDir="column" gap={2} w="full">
{group.options.map((item) => {
const id = getId(item);
const id = getOptionId(item);
return (
<PickerOption
key={id}
id={id}
item={item}
option={item}
setActiveOptionId={setActiveOptionId}
onSelect={onSelect}
isActive={id === activeOptionId}
isSelected={id === selectedItemId}
isDisabled={getIsDisabled?.(item) ?? false}
ItemComponent={ItemComponent}
OptionComponent={OptionComponent}
/>
);
})}
@@ -501,15 +474,15 @@ const itemSx: SystemStyleObject = {
const PickerOption = typedMemo(
<T extends object>(props: {
id: string;
item: T;
option: T;
setActiveOptionId: (key: string) => void;
onSelect: (key: string) => void;
isActive: boolean;
isSelected: boolean;
isDisabled: boolean;
ItemComponent?: React.ComponentType<{ item: T }>;
OptionComponent?: React.ComponentType<{ option: T }>;
}) => {
const { id, item, ItemComponent, setActiveOptionId, onSelect, isActive, isDisabled, isSelected } = props;
const { id, option, OptionComponent, setActiveOptionId, onSelect, isActive, isDisabled, isSelected } = props;
const onPointerMove = useCallback(() => {
setActiveOptionId(id);
}, [id, setActiveOptionId]);
@@ -527,7 +500,7 @@ const PickerOption = typedMemo(
onPointerMove={isDisabled ? undefined : onPointerMove}
onClick={isDisabled ? undefined : onClick}
>
{ItemComponent ? <ItemComponent item={item} /> : <DefaultItemComponent id={id} />}
{OptionComponent ? <OptionComponent option={option} /> : <DefaultOptionComponent id={id} />}
</Box>
);
}

View File

@@ -19,8 +19,8 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type { ImperativeModelPickerHandle, OptionGroup } from 'common/components/Picker/Picker';
import { isMatch, Picker } from 'common/components/Picker/Picker';
import type { Group,ImperativeModelPickerHandle } from 'common/components/Picker/Picker';
import { getRegex, Picker } from 'common/components/Picker/Picker';
import { useDisclosure } from 'common/hooks/useBoolean';
import { typedMemo } from 'common/util/typedMemo';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
@@ -127,8 +127,8 @@ const getIsDisabled = (modelConfig: AnyModelConfig) => {
const MainModelPicker = memo(() => {
const { t } = useTranslation();
const [modelConfigs] = useMainModels();
const grouped = useMemo<OptionGroup<AnyModelConfig, { name: string; description: string }>[]>(() => {
const groups: { [base in BaseModelType]?: OptionGroup<AnyModelConfig, { name: string; description: string }> } = {};
const grouped = useMemo<Group<AnyModelConfig, { name: string; description: string }>[]>(() => {
const groups: { [base in BaseModelType]?: Group<AnyModelConfig, { name: string; description: string }> } = {};
for (const modelConfig of modelConfigs) {
let group = groups[modelConfig.base];
@@ -144,7 +144,7 @@ const MainModelPicker = memo(() => {
group.options.push(modelConfig);
}
const sortedGroups: OptionGroup<AnyModelConfig, { name: string; description: string }>[] = [];
const sortedGroups: Group<AnyModelConfig, { name: string; description: string }>[] = [];
if (groups['flux']) {
sortedGroups.push(groups['flux']);
@@ -220,7 +220,7 @@ const MainModelPicker = memo(() => {
selectedItem={modelConfig}
getIsDisabled={getIsDisabled}
isMatch={isMatch}
ItemComponent={PickerItemComponent}
OptionComponent={PickerItemComponent}
GroupHeaderComponent={PickerGroupHeaderComponent}
noOptionsFallback={<Text>{t('common.noOptions')}</Text>}
noMatchesFallback={<Text>{t('common.noMatches')}</Text>}
@@ -233,7 +233,7 @@ const MainModelPicker = memo(() => {
MainModelPicker.displayName = 'MainModelPicker';
const PickerGroupHeaderComponent = memo(
({ group }: { group: OptionGroup<AnyModelConfig, { name: string; description: string }> }) => {
({ group }: { group: Group<AnyModelConfig, { name: string; description: string }> }) => {
return (
<Flex flexDir="column">
<Text fontSize="sm" fontWeight="semibold">
@@ -269,3 +269,29 @@ export const PickerItemComponent = typedMemo(({ item }: { item: AnyModelConfig }
);
});
PickerItemComponent.displayName = 'PickerItemComponent';
const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
'sd-1': ['sd1', 'sd1.4', 'sd1.5', 'sd-1'],
'sd-2': ['sd2', 'sd2.0', 'sd2.1', 'sd-2'],
'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'],
};
const isMatch = (model: AnyModelConfig, searchTerm: string) => {
const regex = getRegex(searchTerm);
if (
model.name.toLowerCase().includes(searchTerm) ||
regex.test(model.name) ||
(BASE_KEYWORDS[model.base] ?? [model.base]).some((kw) => kw.toLowerCase().includes(searchTerm) || regex.test(kw)) ||
model.type.toLowerCase().includes(searchTerm) ||
regex.test(model.type) ||
(model.description ?? '').toLowerCase().includes(searchTerm) ||
regex.test(model.description ?? '') ||
model.format.toLowerCase().includes(searchTerm) ||
regex.test(model.format)
) {
return true;
}
return false;
};