feat(ui): wip model picker

This commit is contained in:
psychedelicious
2025-04-17 19:49:51 +10:00
parent 9d8a71b362
commit 015dc3ac0d
4 changed files with 212 additions and 107 deletions

View File

@@ -118,6 +118,8 @@
"error": "Error",
"error_withCount_one": "{{count}} error",
"error_withCount_other": "{{count}} errors",
"model_withCount_one": "{{count}} model",
"model_withCount_other": "{{count}} models",
"file": "File",
"folder": "Folder",
"format": "format",
@@ -768,6 +770,7 @@
"description": "Description",
"edit": "Edit",
"fileSize": "File Size",
"filterModels": "Filter models",
"fluxRedux": "FLUX Redux",
"height": "Height",
"huggingFace": "HuggingFace",

View File

@@ -1,11 +1,11 @@
import type { InputProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Divider, Flex, Input, Text } from '@invoke-ai/ui-library';
import type { BoxProps, InputProps } from '@invoke-ai/ui-library';
import { Flex, Input, Text } from '@invoke-ai/ui-library';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { useStateImperative } from 'common/hooks/useStateImperative';
import { fixedForwardRef } from 'common/util/fixedForwardRef';
import { typedMemo } from 'common/util/typedMemo';
import type { ChangeEvent } from 'react';
import type { ChangeEvent, PropsWithChildren } from 'react';
import {
createContext,
useCallback,
@@ -26,7 +26,7 @@ export type Group<T extends object, U = any> = {
options: T[];
};
const isGroup = <T extends object>(option: T | Group<T>): option is Group<T> => {
export const isGroup = <T extends object>(option: T | Group<T>): option is Group<T> => {
return option ? 'options' in option && Array.isArray(option.options) : false;
};
@@ -43,10 +43,19 @@ const DefaultOptionComponent = typedMemo(<T extends object>({ option }: { option
});
DefaultOptionComponent.displayName = 'DefaultOptionComponent';
const DefaultGroupHeaderComponent = typedMemo(<T extends object>({ group }: { group: Group<T> }) => {
return <Text fontWeight="bold">{group.id}</Text>;
});
DefaultGroupHeaderComponent.displayName = 'DefaultGroupHeaderComponent';
const DefaultGroupComponent = typedMemo(
<T extends object>({ group, children }: PropsWithChildren<{ group: Group<T> }>) => {
return (
<Flex flexDir="column" gap={2} w="full">
<Text fontWeight="bold">{group.id}</Text>
<Flex flexDir="column" gap={1} w="full">
{children}
</Flex>
</Flex>
);
}
);
DefaultGroupComponent.displayName = 'DefaultGroupComponent';
const DefaultNoOptionsFallbackComponent = typedMemo(() => {
const { t } = useTranslation();
@@ -68,7 +77,7 @@ const DefaultNoMatchesFallbackComponent = typedMemo(() => {
});
DefaultNoMatchesFallbackComponent.displayName = 'DefaultNoMatchesFallbackComponent';
export type PickerProps<T extends object> = {
export type PickerProps<T extends object, U> = {
options: (T | Group<T>)[];
getOptionId: (option: T) => string;
isMatch: (option: T, searchTerm: string) => boolean;
@@ -80,11 +89,18 @@ export type PickerProps<T extends object> = {
SearchBarComponent?: ReturnType<typeof fixedForwardRef<HTMLInputElement, InputProps>>;
NoOptionsFallbackComponent?: React.ComponentType;
NoMatchesFallbackComponent?: React.ComponentType;
OptionComponent?: React.ComponentType<{ option: T }>;
GroupHeaderComponent?: React.ComponentType<{ group: Group<T> }>;
OptionComponent?: React.ComponentType<
{
option: T;
} & BoxProps
>;
GroupComponent?: React.ComponentType<
PropsWithChildren<{ group: Group<T, U>; activeOptionId: string | undefined } & BoxProps>
>;
};
type PickerContextState<T extends object> = {
type PickerContextState<T extends object, U> = {
options: (T | Group<T>)[];
getOptionId: (option: T) => string;
isMatch: (option: T, searchTerm: string) => boolean;
getIsDisabled?: (option: T) => boolean;
@@ -93,13 +109,15 @@ type PickerContextState<T extends object> = {
SearchBarComponent: ReturnType<typeof fixedForwardRef<HTMLInputElement, InputProps>>;
NoOptionsFallbackComponent: React.ComponentType;
NoMatchesFallbackComponent: React.ComponentType;
OptionComponent: React.ComponentType<{ option: T }>;
GroupHeaderComponent: React.ComponentType<{ group: Group<T> }>;
OptionComponent: React.ComponentType<{ option: T } & BoxProps>;
GroupComponent: React.ComponentType<
PropsWithChildren<{ group: Group<T, U>; activeOptionId: string | undefined } & BoxProps>
>;
};
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const PickerContext = createContext<PickerContextState<any> | null>(null);
const usePickerContext = <T extends object>(): PickerContextState<T> => {
const PickerContext = createContext<PickerContextState<any, any> | null>(null);
export const usePickerContext = <T extends object, U>(): PickerContextState<T, U> => {
const context = useContext(PickerContext);
assert(context !== null, 'usePickerContext must be used within a PickerProvider');
return context;
@@ -176,7 +194,7 @@ const flattenOptions = <T extends object>(options: (T | Group<T>)[]): T[] => {
return flattened;
};
export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
export const Picker = typedMemo(<T extends object, U>(props: PickerProps<T, U>) => {
const {
getOptionId,
options,
@@ -190,14 +208,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
NoMatchesFallbackComponent = DefaultNoMatchesFallbackComponent,
NoOptionsFallbackComponent = DefaultNoOptionsFallbackComponent,
OptionComponent = DefaultOptionComponent,
GroupHeaderComponent = DefaultGroupHeaderComponent,
GroupComponent = DefaultGroupComponent,
} = props;
const [activeOptionId, setActiveOptionId, getActiveOptionId] = useStateImperative(() =>
getFirstOptionId(options, getOptionId)
);
const rootRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
const [filteredOptions, setFilteredOptions] = useState<(T | Group<T>)[]>(options);
const [filteredOptions, setFilteredOptions] = useState<(T | Group<T, U>)[]>(options);
const flattenedOptions = useMemo(() => flattenOptions(options), [options]);
const flattenedFilteredOptions = useMemo(() => flattenOptions(filteredOptions), [filteredOptions]);
const [searchTerm, setSearchTerm] = useState('');
@@ -213,7 +231,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
setActiveOptionId(getFirstOptionId(options, getOptionId));
} else {
const lowercasedSearchTerm = searchTerm.toLowerCase();
const filtered: (T | Group<T>)[] = [];
const filtered: (T | Group<T, U>)[] = [];
for (const item of props.options) {
if (isGroup(item)) {
const filteredItems = item.options.filter((item) => isMatch(item, lowercasedSearchTerm));
@@ -346,6 +364,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
const ctx = useMemo(
() =>
({
options,
getOptionId,
isMatch,
getIsDisabled,
@@ -355,9 +374,10 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
NoOptionsFallbackComponent,
NoMatchesFallbackComponent,
OptionComponent,
GroupHeaderComponent,
}) satisfies PickerContextState<T>,
GroupComponent,
}) satisfies PickerContextState<T, U>,
[
options,
getOptionId,
isMatch,
getIsDisabled,
@@ -367,7 +387,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
NoOptionsFallbackComponent,
NoMatchesFallbackComponent,
OptionComponent,
GroupHeaderComponent,
GroupComponent,
]
);
@@ -385,7 +405,6 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
onKeyDown={onKeyDown}
>
<SearchBarComponent ref={inputRef} value={searchTerm} onChange={onChangeSearchTerm} />
<Divider />
<Flex tabIndex={-1} w="full" flexGrow={1}>
{flattenedOptions.length === 0 && <NoOptionsFallbackComponent />}
{flattenedOptions.length > 0 && flattenedFilteredOptions.length === 0 && <NoMatchesFallbackComponent />}
@@ -413,16 +432,16 @@ const DefaultPickerSearchBarComponent = typedMemo(
DefaultPickerSearchBarComponent.displayName = 'DefaultPickerSearchBarComponent';
const PickerList = typedMemo(
<T extends object>({
<T extends object, U>({
items,
activeOptionId,
selectedItemId,
}: {
items: (T | Group<T>)[];
items: (T | Group<T, U>)[];
activeOptionId: string | undefined;
selectedItemId: string | undefined;
}) => {
const { getOptionId, getIsDisabled } = usePickerContext<T>();
const { getOptionId, getIsDisabled } = usePickerContext<T, U>();
if (items.length === 0) {
return (
@@ -470,63 +489,47 @@ const PickerList = typedMemo(
PickerList.displayName = 'PickerList';
const PickerOptionGroup = typedMemo(
<T extends object>({
<T extends object, U>({
group,
activeOptionId,
selectedItemId,
}: {
group: Group<T>;
group: Group<T, U>;
activeOptionId: string | undefined;
selectedItemId: string | undefined;
}) => {
const { getOptionId, GroupHeaderComponent, getIsDisabled } = usePickerContext<T>();
const { getOptionId, GroupComponent, getIsDisabled } = usePickerContext<T, U>();
return (
<Flex flexDir="column" gap={2} w="full">
<GroupHeaderComponent group={group} />
<Flex flexDir="column" gap={1} w="full">
{group.options.map((item) => {
const id = getOptionId(item);
return (
<PickerOption
key={id}
id={id}
option={item}
isActive={id === activeOptionId}
isSelected={id === selectedItemId}
isDisabled={getIsDisabled?.(item) ?? false}
/>
);
})}
</Flex>
</Flex>
<GroupComponent group={group} activeOptionId={activeOptionId}>
{group.options.map((item) => {
const id = getOptionId(item);
return (
<PickerOption
key={id}
id={id}
option={item}
isActive={id === activeOptionId}
isSelected={id === selectedItemId}
isDisabled={getIsDisabled?.(item) ?? false}
/>
);
})}
</GroupComponent>
);
}
);
PickerOptionGroup.displayName = 'PickerOptionGroup';
const itemSx: SystemStyleObject = {
display: 'flex',
flexDir: 'column',
p: 2,
cursor: 'pointer',
borderRadius: 'base',
'&[data-selected="true"]': {
borderColor: 'invokeBlue.300',
borderWidth: 1,
},
'&[data-active="true"]': {
bg: 'base.700',
},
'&[data-disabled="true"]': {
cursor: 'not-allowed',
opacity: 0.5,
},
};
const PickerOption = typedMemo(
<T extends object>(props: { id: string; option: T; isActive: boolean; isSelected: boolean; isDisabled: boolean }) => {
const { OptionComponent, setActiveOptionId, onSelectById } = usePickerContext<T>();
<T extends object, U>(props: {
id: string;
option: T;
isActive: boolean;
isSelected: boolean;
isDisabled: boolean;
}) => {
const { OptionComponent, setActiveOptionId, onSelectById } = usePickerContext<T, U>();
const { id, option, isActive, isDisabled, isSelected } = props;
const onPointerMove = useCallback(() => {
setActiveOptionId(id);
@@ -535,18 +538,16 @@ const PickerOption = typedMemo(
onSelectById(id);
}, [id, onSelectById]);
return (
<Box
role="option"
sx={itemSx}
<OptionComponent
tabIndex={-1}
option={option}
id={id}
data-disabled={isDisabled}
data-selected={isSelected}
data-active={isActive}
onPointerMove={isDisabled ? undefined : onPointerMove}
onClick={isDisabled ? undefined : onClick}
>
<OptionComponent option={option} />
</Box>
/>
);
}
);

View File

@@ -7,7 +7,7 @@ type Props = {
base: BaseModelType;
};
const BASE_COLOR_MAP: Record<BaseModelType, string> = {
export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
@@ -15,7 +15,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'gold',
cogview4: 'orange',
cogview4: 'red',
};
const ModelBaseBadge = ({ base }: Props) => {

View File

@@ -1,11 +1,13 @@
import type { FormLabelProps, InputProps } from '@invoke-ai/ui-library';
import type { BoxProps, FormLabelProps, InputProps, SystemStyleObject } from '@invoke-ai/ui-library';
import {
Box,
Button,
Collapse,
Expander,
Flex,
FormControlGroup,
FormLabel,
Icon,
Input,
Popover,
PopoverArrow,
@@ -24,13 +26,14 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf
import type { Group, ImperativeModelPickerHandle } from 'common/components/Picker/Picker';
import { getRegex, Picker } from 'common/components/Picker/Picker';
import { useDisclosure } from 'common/hooks/useBoolean';
import { useStateImperative } from 'common/hooks/useStateImperative';
import { fixedForwardRef } from 'common/util/fixedForwardRef';
import { typedMemo } from 'common/util/typedMemo';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectIsCogView4, selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import { BASE_COLOR_MAP } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import ModelImage from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelImage';
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
import ParamGuidance from 'features/parameters/components/Core/ParamGuidance';
@@ -41,11 +44,13 @@ import { UseDefaultSettingsButton } from 'features/parameters/components/MainMod
import ParamUpscaleCFGScale from 'features/parameters/components/Upscale/ParamUpscaleCFGScale';
import ParamUpscaleScheduler from 'features/parameters/components/Upscale/ParamUpscaleScheduler';
import { modelSelected } from 'features/parameters/store/actions';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { filesize } from 'filesize';
import { memo, useCallback, useMemo, useRef } from 'react';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
import { useMainModels } from 'services/api/hooks/modelsByType';
@@ -121,22 +126,26 @@ export const GenerationSettingsAccordion = memo(() => {
GenerationSettingsAccordion.displayName = 'GenerationSettingsAccordion';
const getOptionId = (modelConfig: AnyModelConfig) => modelConfig.key;
const getIsDisabled = (modelConfig: AnyModelConfig) => {
return modelConfig.base === 'flux';
type GroupData = {
base: BaseModelType;
description: string;
};
const MainModelPicker = memo(() => {
const { t } = useTranslation();
const [modelConfigs] = useMainModels();
const grouped = useMemo<Group<AnyModelConfig, { name: string; description: string }>[]>(() => {
const groups: { [base in BaseModelType]?: Group<AnyModelConfig, { name: string; description: string }> } = {};
const grouped = useMemo<Group<AnyModelConfig, GroupData>[]>(() => {
const groups: {
[base in BaseModelType]?: Group<AnyModelConfig, GroupData>;
} = {};
for (const modelConfig of modelConfigs) {
let group = groups[modelConfig.base];
if (!group) {
group = {
id: modelConfig.base,
data: { name: modelConfig.base, description: `A brief description of ${modelConfig.base} models.` },
data: { base: modelConfig.base, description: `A brief description of ${modelConfig.base} models.` },
options: [],
};
groups[modelConfig.base] = group;
@@ -145,7 +154,7 @@ const MainModelPicker = memo(() => {
group.options.push(modelConfig);
}
const sortedGroups: Group<AnyModelConfig, { name: string; description: string }>[] = [];
const sortedGroups: Group<AnyModelConfig, GroupData>[] = [];
if (groups['flux']) {
sortedGroups.push(groups['flux']);
@@ -210,11 +219,11 @@ const MainModelPicker = memo(() => {
<NavigateToModelManagerButton />
<UseDefaultSettingsButton />
</Flex>
<Portal>
<Portal appendToParentPortal={false}>
<PopoverContent p={0} w={448} h={512}>
<PopoverArrow />
<PopoverBody p={0} w="full" h="full">
<Picker<AnyModelConfig>
<Picker<AnyModelConfig, GroupData>
handleRef={pickerRef}
options={grouped}
getOptionId={getOptionId}
@@ -222,8 +231,8 @@ const MainModelPicker = memo(() => {
selectedItem={modelConfig}
// getIsDisabled={getIsDisabled}
isMatch={isMatch}
OptionComponent={PickerItemComponent}
GroupHeaderComponent={PickerGroupHeaderComponent}
OptionComponent={PickerOptionComponent}
GroupComponent={PickerGroupComponent}
SearchBarComponent={SearchBarComponent}
/>
</PopoverBody>
@@ -238,34 +247,127 @@ const SearchBarComponent = typedMemo(
fixedForwardRef<HTMLInputElement, InputProps>((props, ref) => {
const { t } = useTranslation();
return (
<Flex gap={2} alignItems="center">
<Input ref={ref} {...props} placeholder={t('common.search')} />
<NavigateToModelManagerButton />
<Flex flexDir="column" w="full">
<Flex gap={2} alignItems="center">
<Input ref={ref} {...props} placeholder={t('modelManager.filterModels')} />
<NavigateToModelManagerButton />
</Flex>
<Flex gap={2} alignItems="center"></Flex>
</Flex>
);
})
);
SearchBarComponent.displayName = 'SearchBarComponent';
const PickerGroupHeaderComponent = memo(
({ group }: { group: Group<AnyModelConfig, { name: string; description: string }> }) => {
const toggleButtonSx = {
"&[data-expanded='true']": {
transform: 'rotate(180deg)',
},
} satisfies SystemStyleObject;
const PickerGroupComponent = memo(
({
group,
activeOptionId,
children,
}: PropsWithChildren<{ group: Group<AnyModelConfig, GroupData>; activeOptionId: string | undefined }>) => {
const [isOpen, setIsOpen, getIsOpen] = useStateImperative(true);
useEffect(() => {
if (group.options.some((option) => option.key === activeOptionId) && !getIsOpen()) {
setIsOpen(true);
}
}, [activeOptionId, getIsOpen, group.options, setIsOpen]);
const toggle = useCallback(() => {
setIsOpen((prev) => !prev);
}, [setIsOpen]);
return (
<Flex flexDir="column" ps={8}>
<Text fontSize="sm" fontWeight="semibold">
{`${group.data.name} (${group.options.length} models)`}
</Text>
<Text color="base.200" fontStyle="italic">
{group.data.description}
</Text>
<Flex
flexDir="column"
w="full"
borderLeftColor={`${BASE_COLOR_MAP[group.data.base]}.300`}
borderLeftWidth={4}
ps={2}
>
<GroupHeader group={group} isOpen={isOpen} toggle={toggle} />
<Collapse in={isOpen} animateOpacity>
<Flex flexDir="column" gap={1} w="full" pb={2}>
{children}
</Flex>
</Collapse>
</Flex>
);
}
);
PickerGroupHeaderComponent.displayName = 'PickerGroupHeaderComponent';
PickerGroupComponent.displayName = 'PickerGroupComponent';
export const PickerItemComponent = typedMemo(({ option }: { option: AnyModelConfig }) => {
const GroupHeader = memo(
({
group,
isOpen,
toggle,
...rest
}: { group: Group<AnyModelConfig, GroupData>; isOpen: boolean; toggle: () => void } & BoxProps) => {
const { t } = useTranslation();
return (
<Flex
{...rest}
role="button"
alignItems="center"
px={2}
onClick={toggle}
userSelect="none"
position="sticky"
top={0}
bg="base.800"
pb={2}
>
<Flex flexDir="column" flex={1}>
<Flex gap={2} alignItems="center">
<Text fontSize="sm" fontWeight="semibold" color={`${BASE_COLOR_MAP[group.data.base]}.300`}>
{MODEL_TYPE_SHORT_MAP[group.data.base]}
</Text>
<Text fontSize="sm" color="base.300" noOfLines={1}>
{t('common.model_withCount', { count: group.options.length })}
</Text>
</Flex>
<Text color="base.200" fontStyle="italic">
{group.data.description}
</Text>
<Spacer />
</Flex>
<Icon color="base.300" as={PiCaretDownBold} sx={toggleButtonSx} data-expanded={isOpen} boxSize={4} />
</Flex>
);
}
);
GroupHeader.displayName = 'GroupHeader';
const optionSx: SystemStyleObject = {
p: 2,
gap: 2,
cursor: 'pointer',
borderRadius: 'base',
'&[data-selected="true"]': {
bg: 'base.700',
'&[data-active="true"]': {
bg: 'base.650',
},
},
'&[data-active="true"]': {
bg: 'base.750',
},
'&[data-disabled="true"]': {
cursor: 'not-allowed',
opacity: 0.5,
},
scrollMarginTop: '42px', // magic number, this is the height of the header
};
export const PickerOptionComponent = typedMemo(({ option, ...rest }: { option: AnyModelConfig } & BoxProps) => {
return (
<Flex tabIndex={-1} gap={2}>
<Flex {...rest} sx={optionSx}>
<ModelImage image_url={option.cover_image} />
<Flex flexDir="column" gap={2} flex={1}>
<Flex gap={2} alignItems="center">
@@ -276,14 +378,13 @@ export const PickerItemComponent = typedMemo(({ option }: { option: AnyModelConf
<Text variant="subtext" fontStyle="italic" noOfLines={1} flexShrink={0} overflow="visible">
{filesize(option.file_size)}
</Text>
<ModelBaseBadge base={option.base} />
</Flex>
{option.description && <Text color="base.200">{option.description}</Text>}
</Flex>
</Flex>
);
});
PickerItemComponent.displayName = 'PickerItemComponent';
PickerOptionComponent.displayName = 'PickerItemComponent';
const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
'sd-1': ['sd1', 'sd1.4', 'sd1.5', 'sd-1'],