mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): iterate on model combobox (wip)
This commit is contained in:
@@ -8,17 +8,16 @@ import { NavigateToModelManagerButton } from 'features/parameters/components/Mai
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig, BaseModelType } from 'services/api/types';
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
export type OptionGroup<T extends object, U = any> = {
|
||||
export type Group<T extends object, U = any> = {
|
||||
id: string;
|
||||
data: U;
|
||||
options: T[];
|
||||
};
|
||||
|
||||
const isGroup = <T extends object>(item: T | OptionGroup<T>): item is OptionGroup<T> => {
|
||||
return item ? 'options' in item && Array.isArray(item.options) : false;
|
||||
const isGroup = <T extends object>(option: T | Group<T>): option is Group<T> => {
|
||||
return option ? 'options' in option && Array.isArray(option.options) : false;
|
||||
};
|
||||
|
||||
export type ImperativeModelPickerHandle = {
|
||||
@@ -28,7 +27,7 @@ export type ImperativeModelPickerHandle = {
|
||||
setSearchTerm: (searchTerm: string) => void;
|
||||
};
|
||||
|
||||
const DefaultItemComponent = ({ id }: { id: string }) => {
|
||||
const DefaultOptionComponent = ({ id }: { id: string }) => {
|
||||
return <Text fontWeight="bold">{id}</Text>;
|
||||
};
|
||||
|
||||
@@ -37,21 +36,21 @@ const DefaultGroupHeaderComponent = ({ id }: { id: string }) => {
|
||||
};
|
||||
|
||||
export type PickerProps<T extends object> = {
|
||||
options: (T | OptionGroup<T>)[];
|
||||
getId: (item: T) => string;
|
||||
isMatch: (item: T, searchTerm: string) => boolean;
|
||||
getIsDisabled?: (item: T) => boolean;
|
||||
options: (T | Group<T>)[];
|
||||
getOptionId: (option: T) => string;
|
||||
isMatch: (option: T, searchTerm: string) => boolean;
|
||||
getIsDisabled?: (option: T) => boolean;
|
||||
selectedItem?: T;
|
||||
onSelect?: (item: T) => void;
|
||||
onSelect?: (option: T) => void;
|
||||
onClose?: () => void;
|
||||
noOptionsFallback?: React.ReactNode;
|
||||
noMatchesFallback?: React.ReactNode;
|
||||
handleRef?: React.Ref<ImperativeModelPickerHandle>;
|
||||
ItemComponent?: React.ComponentType<{ item: T }>;
|
||||
GroupHeaderComponent?: React.ComponentType<{ group: OptionGroup<T> }>;
|
||||
OptionComponent?: React.ComponentType<{ option: T }>;
|
||||
GroupHeaderComponent?: React.ComponentType<{ group: Group<T> }>;
|
||||
};
|
||||
|
||||
const getRegex = (searchTerm: string) =>
|
||||
export const getRegex = (searchTerm: string) =>
|
||||
new RegExp(
|
||||
searchTerm
|
||||
.trim()
|
||||
@@ -61,33 +60,7 @@ const getRegex = (searchTerm: string) =>
|
||||
'gi'
|
||||
);
|
||||
|
||||
const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
|
||||
'sd-1': ['sd1', 'sd1.4', 'sd1.5', 'sd-1'],
|
||||
'sd-2': ['sd2', 'sd2.0', 'sd2.1', 'sd-2'],
|
||||
'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'],
|
||||
};
|
||||
|
||||
export const isMatch = (model: AnyModelConfig, searchTerm: string) => {
|
||||
const regex = getRegex(searchTerm);
|
||||
|
||||
if (
|
||||
model.name.toLowerCase().includes(searchTerm) ||
|
||||
regex.test(model.name) ||
|
||||
(BASE_KEYWORDS[model.base] ?? [model.base]).some((kw) => kw.toLowerCase().includes(searchTerm) || regex.test(kw)) ||
|
||||
model.type.toLowerCase().includes(searchTerm) ||
|
||||
regex.test(model.type) ||
|
||||
(model.description ?? '').toLowerCase().includes(searchTerm) ||
|
||||
regex.test(model.description ?? '') ||
|
||||
model.format.toLowerCase().includes(searchTerm) ||
|
||||
regex.test(model.format)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
const getFirstOption = <T extends object>(options: (T | OptionGroup<T>)[]): T | undefined => {
|
||||
const getFirstOption = <T extends object>(options: (T | Group<T>)[]): T | undefined => {
|
||||
const firstOptionOrGroup = options[0];
|
||||
if (!firstOptionOrGroup) {
|
||||
return;
|
||||
@@ -100,37 +73,37 @@ const getFirstOption = <T extends object>(options: (T | OptionGroup<T>)[]): T |
|
||||
};
|
||||
|
||||
const getFirstOptionId = <T extends object>(
|
||||
options: (T | OptionGroup<T>)[],
|
||||
getId: (item: T) => string
|
||||
options: (T | Group<T>)[],
|
||||
getOptionId: (item: T) => string
|
||||
): string | undefined => {
|
||||
const firstOptionOrGroup = getFirstOption(options);
|
||||
if (firstOptionOrGroup) {
|
||||
return getId(firstOptionOrGroup);
|
||||
return getOptionId(firstOptionOrGroup);
|
||||
} else {
|
||||
return undefined;
|
||||
}
|
||||
};
|
||||
|
||||
const findOption = <T extends object>(
|
||||
options: (T | OptionGroup<T>)[],
|
||||
options: (T | Group<T>)[],
|
||||
id: string,
|
||||
getId: (item: T) => string
|
||||
getOptionId: (item: T) => string
|
||||
): T | undefined => {
|
||||
for (const optionOrGroup of options) {
|
||||
if (isGroup(optionOrGroup)) {
|
||||
const option = optionOrGroup.options.find((opt) => getId(opt) === id);
|
||||
const option = optionOrGroup.options.find((opt) => getOptionId(opt) === id);
|
||||
if (option) {
|
||||
return option;
|
||||
}
|
||||
} else {
|
||||
if (getId(optionOrGroup) === id) {
|
||||
if (getOptionId(optionOrGroup) === id) {
|
||||
return optionOrGroup;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const flattenOptions = <T extends object>(options: (T | OptionGroup<T>)[]): T[] => {
|
||||
const flattenOptions = <T extends object>(options: (T | Group<T>)[]): T[] => {
|
||||
const flattened: T[] = [];
|
||||
for (const optionOrGroup of options) {
|
||||
if (isGroup(optionOrGroup)) {
|
||||
@@ -145,7 +118,7 @@ const flattenOptions = <T extends object>(options: (T | OptionGroup<T>)[]): T[]
|
||||
export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
const { t } = useTranslation();
|
||||
const {
|
||||
getId,
|
||||
getOptionId,
|
||||
options,
|
||||
handleRef,
|
||||
isMatch,
|
||||
@@ -155,15 +128,15 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
onClose,
|
||||
onSelect,
|
||||
selectedItem,
|
||||
ItemComponent,
|
||||
OptionComponent,
|
||||
GroupHeaderComponent,
|
||||
} = props;
|
||||
const [activeOptionId, setActiveOptionId, getActiveOptionId] = useStateImperative(() =>
|
||||
getFirstOptionId(options, getId)
|
||||
getFirstOptionId(options, getOptionId)
|
||||
);
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const [filteredOptions, setFilteredOptions] = useState<(T | OptionGroup<T>)[]>(options);
|
||||
const [filteredOptions, setFilteredOptions] = useState<(T | Group<T>)[]>(options);
|
||||
const flattenedOptions = useMemo(() => flattenOptions(options), [options]);
|
||||
const flattenedFilteredOptions = useMemo(() => flattenOptions(filteredOptions), [filteredOptions]);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
@@ -176,10 +149,10 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
useEffect(() => {
|
||||
if (!searchTerm) {
|
||||
setFilteredOptions(options);
|
||||
setActiveOptionId(getFirstOptionId(options, getId));
|
||||
setActiveOptionId(getFirstOptionId(options, getOptionId));
|
||||
} else {
|
||||
const lowercasedSearchTerm = searchTerm.toLowerCase();
|
||||
const filtered: (T | OptionGroup<T>)[] = [];
|
||||
const filtered: (T | Group<T>)[] = [];
|
||||
for (const item of props.options) {
|
||||
if (isGroup(item)) {
|
||||
const filteredItems = item.options.filter(
|
||||
@@ -195,20 +168,20 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
}
|
||||
}
|
||||
setFilteredOptions(filtered);
|
||||
setActiveOptionId(getFirstOptionId(filtered, getId));
|
||||
setActiveOptionId(getFirstOptionId(filtered, getOptionId));
|
||||
}
|
||||
}, [searchTerm, setActiveOptionId, props.options, options, getId, isMatch, getIsDisabled]);
|
||||
}, [searchTerm, setActiveOptionId, props.options, options, getOptionId, isMatch, getIsDisabled]);
|
||||
|
||||
const onSelectInternal = useCallback(
|
||||
(id: string) => {
|
||||
const item = findOption(options, id, getId);
|
||||
const item = findOption(options, id, getOptionId);
|
||||
if (!item) {
|
||||
// Model not found? We should never get here.
|
||||
return;
|
||||
}
|
||||
onSelect?.(item);
|
||||
},
|
||||
[getId, options, onSelect]
|
||||
[getOptionId, options, onSelect]
|
||||
);
|
||||
|
||||
const setValueAndScrollIntoView = useCallback(
|
||||
@@ -237,11 +210,11 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
if (e.metaKey) {
|
||||
const item = flattenedFilteredOptions.at(0);
|
||||
if (item) {
|
||||
setValueAndScrollIntoView(getId(item));
|
||||
setValueAndScrollIntoView(getOptionId(item));
|
||||
}
|
||||
return;
|
||||
}
|
||||
const currentIndex = flattenedFilteredOptions.findIndex((item) => getId(item) === activeOptionId);
|
||||
const currentIndex = flattenedFilteredOptions.findIndex((item) => getOptionId(item) === activeOptionId);
|
||||
if (currentIndex < 0) {
|
||||
return;
|
||||
}
|
||||
@@ -251,10 +224,10 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
}
|
||||
const item = flattenedFilteredOptions.at(newIndex);
|
||||
if (item) {
|
||||
setValueAndScrollIntoView(getId(item));
|
||||
setValueAndScrollIntoView(getOptionId(item));
|
||||
}
|
||||
},
|
||||
[getActiveOptionId, flattenedFilteredOptions, setValueAndScrollIntoView, getId]
|
||||
[getActiveOptionId, flattenedFilteredOptions, setValueAndScrollIntoView, getOptionId]
|
||||
);
|
||||
|
||||
const next = useCallback(
|
||||
@@ -267,12 +240,12 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
if (e.metaKey) {
|
||||
const item = flattenedFilteredOptions.at(-1);
|
||||
if (item) {
|
||||
setValueAndScrollIntoView(getId(item));
|
||||
setValueAndScrollIntoView(getOptionId(item));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const currentIndex = flattenedFilteredOptions.findIndex((item) => getId(item) === activeOptionId);
|
||||
const currentIndex = flattenedFilteredOptions.findIndex((item) => getOptionId(item) === activeOptionId);
|
||||
if (currentIndex < 0) {
|
||||
return;
|
||||
}
|
||||
@@ -282,10 +255,10 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
}
|
||||
const item = flattenedFilteredOptions.at(newIndex);
|
||||
if (item) {
|
||||
setValueAndScrollIntoView(getId(item));
|
||||
setValueAndScrollIntoView(getOptionId(item));
|
||||
}
|
||||
},
|
||||
[getActiveOptionId, flattenedFilteredOptions, setValueAndScrollIntoView, getId]
|
||||
[getActiveOptionId, flattenedFilteredOptions, setValueAndScrollIntoView, getOptionId]
|
||||
);
|
||||
|
||||
const onKeyDown = useCallback(
|
||||
@@ -296,7 +269,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
next(e);
|
||||
} else if (e.key === 'Enter') {
|
||||
const activeOptionId = getActiveOptionId();
|
||||
const item = flattenedFilteredOptions.find((item) => getId(item) === activeOptionId);
|
||||
const item = flattenedFilteredOptions.find((item) => getOptionId(item) === activeOptionId);
|
||||
if (!item) {
|
||||
// Model not found? We should never get here.
|
||||
return;
|
||||
@@ -310,7 +283,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
inputRef.current?.select();
|
||||
}
|
||||
},
|
||||
[prev, next, getActiveOptionId, flattenedFilteredOptions, onSelect, getId, onClose]
|
||||
[prev, next, getActiveOptionId, flattenedFilteredOptions, onSelect, getOptionId, onClose]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -337,13 +310,13 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
<ScrollableContent>
|
||||
<PickerList
|
||||
items={filteredOptions}
|
||||
getId={getId}
|
||||
getOptionId={getOptionId}
|
||||
activeOptionId={activeOptionId}
|
||||
setActiveOptionId={setActiveOptionId}
|
||||
selectedItemId={selectedItem ? getId(selectedItem) : undefined}
|
||||
selectedItemId={selectedItem ? getOptionId(selectedItem) : undefined}
|
||||
onSelect={onSelectInternal}
|
||||
getIsDisabled={getIsDisabled}
|
||||
ItemComponent={ItemComponent}
|
||||
OptionComponent={OptionComponent}
|
||||
GroupHeaderComponent={GroupHeaderComponent}
|
||||
/>
|
||||
</ScrollableContent>
|
||||
@@ -361,20 +334,20 @@ const PickerList = typedMemo(
|
||||
setActiveOptionId,
|
||||
selectedItemId,
|
||||
onSelect,
|
||||
getId,
|
||||
getOptionId,
|
||||
getIsDisabled,
|
||||
ItemComponent,
|
||||
OptionComponent,
|
||||
GroupHeaderComponent,
|
||||
}: {
|
||||
items: (T | OptionGroup<T>)[];
|
||||
items: (T | Group<T>)[];
|
||||
activeOptionId: string | undefined;
|
||||
setActiveOptionId: (key: string) => void;
|
||||
selectedItemId: string | undefined;
|
||||
onSelect: (key: string) => void;
|
||||
getId: (item: T) => string;
|
||||
getIsDisabled?: (item: T) => boolean;
|
||||
ItemComponent?: React.ComponentType<{ item: T }>;
|
||||
GroupHeaderComponent?: React.ComponentType<{ group: OptionGroup<T> }>;
|
||||
getOptionId: (option: T) => string;
|
||||
getIsDisabled?: (option: T) => boolean;
|
||||
OptionComponent?: React.ComponentType<{ option: T }>;
|
||||
GroupHeaderComponent?: React.ComponentType<{ group: Group<T> }>;
|
||||
}) => {
|
||||
if (items.length === 0) {
|
||||
return (
|
||||
@@ -399,27 +372,27 @@ const PickerList = typedMemo(
|
||||
group={itemOrGroup}
|
||||
setActiveOptionId={setActiveOptionId}
|
||||
activeOptionId={activeOptionId}
|
||||
getId={getId}
|
||||
getOptionId={getOptionId}
|
||||
onSelect={onSelect}
|
||||
selectedItemId={selectedItemId}
|
||||
ItemComponent={ItemComponent}
|
||||
OptionComponent={OptionComponent}
|
||||
getIsDisabled={getIsDisabled}
|
||||
GroupHeaderComponent={GroupHeaderComponent}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
const id = getId(itemOrGroup);
|
||||
const id = getOptionId(itemOrGroup);
|
||||
return (
|
||||
<PickerOption
|
||||
key={id}
|
||||
id={id}
|
||||
item={itemOrGroup}
|
||||
option={itemOrGroup}
|
||||
setActiveOptionId={setActiveOptionId}
|
||||
onSelect={onSelect}
|
||||
isActive={id === activeOptionId}
|
||||
isSelected={id === selectedItemId}
|
||||
isDisabled={getIsDisabled?.(itemOrGroup) ?? false}
|
||||
ItemComponent={ItemComponent}
|
||||
OptionComponent={OptionComponent}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -433,42 +406,42 @@ PickerList.displayName = 'PickerList';
|
||||
const PickerOptionGroup = typedMemo(
|
||||
<T extends object>({
|
||||
group,
|
||||
getId,
|
||||
getOptionId,
|
||||
setActiveOptionId,
|
||||
onSelect,
|
||||
activeOptionId,
|
||||
selectedItemId,
|
||||
getIsDisabled,
|
||||
ItemComponent,
|
||||
OptionComponent,
|
||||
GroupHeaderComponent,
|
||||
}: {
|
||||
group: OptionGroup<T>;
|
||||
getId: (item: T) => string;
|
||||
group: Group<T>;
|
||||
getOptionId: (option: T) => string;
|
||||
setActiveOptionId: (key: string) => void;
|
||||
onSelect: (key: string) => void;
|
||||
activeOptionId: string | undefined;
|
||||
selectedItemId: string | undefined;
|
||||
getIsDisabled?: (item: T) => boolean;
|
||||
ItemComponent?: React.ComponentType<{ item: T }>;
|
||||
GroupHeaderComponent?: React.ComponentType<{ group: OptionGroup<T> }>;
|
||||
getIsDisabled?: (option: T) => boolean;
|
||||
OptionComponent?: React.ComponentType<{ option: T }>;
|
||||
GroupHeaderComponent?: React.ComponentType<{ group: Group<T> }>;
|
||||
}) => {
|
||||
return (
|
||||
<Flex key={group.id} flexDir="column" gap={2} w="full">
|
||||
{GroupHeaderComponent ? <GroupHeaderComponent group={group} /> : <DefaultGroupHeaderComponent id={group.id} />}
|
||||
<Flex flexDir="column" gap={2} w="full">
|
||||
{group.options.map((item) => {
|
||||
const id = getId(item);
|
||||
const id = getOptionId(item);
|
||||
return (
|
||||
<PickerOption
|
||||
key={id}
|
||||
id={id}
|
||||
item={item}
|
||||
option={item}
|
||||
setActiveOptionId={setActiveOptionId}
|
||||
onSelect={onSelect}
|
||||
isActive={id === activeOptionId}
|
||||
isSelected={id === selectedItemId}
|
||||
isDisabled={getIsDisabled?.(item) ?? false}
|
||||
ItemComponent={ItemComponent}
|
||||
OptionComponent={OptionComponent}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
@@ -501,15 +474,15 @@ const itemSx: SystemStyleObject = {
|
||||
const PickerOption = typedMemo(
|
||||
<T extends object>(props: {
|
||||
id: string;
|
||||
item: T;
|
||||
option: T;
|
||||
setActiveOptionId: (key: string) => void;
|
||||
onSelect: (key: string) => void;
|
||||
isActive: boolean;
|
||||
isSelected: boolean;
|
||||
isDisabled: boolean;
|
||||
ItemComponent?: React.ComponentType<{ item: T }>;
|
||||
OptionComponent?: React.ComponentType<{ option: T }>;
|
||||
}) => {
|
||||
const { id, item, ItemComponent, setActiveOptionId, onSelect, isActive, isDisabled, isSelected } = props;
|
||||
const { id, option, OptionComponent, setActiveOptionId, onSelect, isActive, isDisabled, isSelected } = props;
|
||||
const onPointerMove = useCallback(() => {
|
||||
setActiveOptionId(id);
|
||||
}, [id, setActiveOptionId]);
|
||||
@@ -527,7 +500,7 @@ const PickerOption = typedMemo(
|
||||
onPointerMove={isDisabled ? undefined : onPointerMove}
|
||||
onClick={isDisabled ? undefined : onClick}
|
||||
>
|
||||
{ItemComponent ? <ItemComponent item={item} /> : <DefaultItemComponent id={id} />}
|
||||
{OptionComponent ? <OptionComponent option={option} /> : <DefaultOptionComponent id={id} />}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -19,8 +19,8 @@ import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import type { ImperativeModelPickerHandle, OptionGroup } from 'common/components/Picker/Picker';
|
||||
import { isMatch, Picker } from 'common/components/Picker/Picker';
|
||||
import type { Group,ImperativeModelPickerHandle } from 'common/components/Picker/Picker';
|
||||
import { getRegex, Picker } from 'common/components/Picker/Picker';
|
||||
import { useDisclosure } from 'common/hooks/useBoolean';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
@@ -127,8 +127,8 @@ const getIsDisabled = (modelConfig: AnyModelConfig) => {
|
||||
const MainModelPicker = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const [modelConfigs] = useMainModels();
|
||||
const grouped = useMemo<OptionGroup<AnyModelConfig, { name: string; description: string }>[]>(() => {
|
||||
const groups: { [base in BaseModelType]?: OptionGroup<AnyModelConfig, { name: string; description: string }> } = {};
|
||||
const grouped = useMemo<Group<AnyModelConfig, { name: string; description: string }>[]>(() => {
|
||||
const groups: { [base in BaseModelType]?: Group<AnyModelConfig, { name: string; description: string }> } = {};
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
let group = groups[modelConfig.base];
|
||||
@@ -144,7 +144,7 @@ const MainModelPicker = memo(() => {
|
||||
group.options.push(modelConfig);
|
||||
}
|
||||
|
||||
const sortedGroups: OptionGroup<AnyModelConfig, { name: string; description: string }>[] = [];
|
||||
const sortedGroups: Group<AnyModelConfig, { name: string; description: string }>[] = [];
|
||||
|
||||
if (groups['flux']) {
|
||||
sortedGroups.push(groups['flux']);
|
||||
@@ -220,7 +220,7 @@ const MainModelPicker = memo(() => {
|
||||
selectedItem={modelConfig}
|
||||
getIsDisabled={getIsDisabled}
|
||||
isMatch={isMatch}
|
||||
ItemComponent={PickerItemComponent}
|
||||
OptionComponent={PickerItemComponent}
|
||||
GroupHeaderComponent={PickerGroupHeaderComponent}
|
||||
noOptionsFallback={<Text>{t('common.noOptions')}</Text>}
|
||||
noMatchesFallback={<Text>{t('common.noMatches')}</Text>}
|
||||
@@ -233,7 +233,7 @@ const MainModelPicker = memo(() => {
|
||||
MainModelPicker.displayName = 'MainModelPicker';
|
||||
|
||||
const PickerGroupHeaderComponent = memo(
|
||||
({ group }: { group: OptionGroup<AnyModelConfig, { name: string; description: string }> }) => {
|
||||
({ group }: { group: Group<AnyModelConfig, { name: string; description: string }> }) => {
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text fontSize="sm" fontWeight="semibold">
|
||||
@@ -269,3 +269,29 @@ export const PickerItemComponent = typedMemo(({ item }: { item: AnyModelConfig }
|
||||
);
|
||||
});
|
||||
PickerItemComponent.displayName = 'PickerItemComponent';
|
||||
|
||||
const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
|
||||
'sd-1': ['sd1', 'sd1.4', 'sd1.5', 'sd-1'],
|
||||
'sd-2': ['sd2', 'sd2.0', 'sd2.1', 'sd-2'],
|
||||
'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'],
|
||||
};
|
||||
|
||||
const isMatch = (model: AnyModelConfig, searchTerm: string) => {
|
||||
const regex = getRegex(searchTerm);
|
||||
|
||||
if (
|
||||
model.name.toLowerCase().includes(searchTerm) ||
|
||||
regex.test(model.name) ||
|
||||
(BASE_KEYWORDS[model.base] ?? [model.base]).some((kw) => kw.toLowerCase().includes(searchTerm) || regex.test(kw)) ||
|
||||
model.type.toLowerCase().includes(searchTerm) ||
|
||||
regex.test(model.type) ||
|
||||
(model.description ?? '').toLowerCase().includes(searchTerm) ||
|
||||
regex.test(model.description ?? '') ||
|
||||
model.format.toLowerCase().includes(searchTerm) ||
|
||||
regex.test(model.format)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user