mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): wip model picker
This commit is contained in:
@@ -118,6 +118,8 @@
|
||||
"error": "Error",
|
||||
"error_withCount_one": "{{count}} error",
|
||||
"error_withCount_other": "{{count}} errors",
|
||||
"model_withCount_one": "{{count}} model",
|
||||
"model_withCount_other": "{{count}} models",
|
||||
"file": "File",
|
||||
"folder": "Folder",
|
||||
"format": "format",
|
||||
@@ -768,6 +770,7 @@
|
||||
"description": "Description",
|
||||
"edit": "Edit",
|
||||
"fileSize": "File Size",
|
||||
"filterModels": "Filter models",
|
||||
"fluxRedux": "FLUX Redux",
|
||||
"height": "Height",
|
||||
"huggingFace": "HuggingFace",
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import type { InputProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Divider, Flex, Input, Text } from '@invoke-ai/ui-library';
|
||||
import type { BoxProps, InputProps } from '@invoke-ai/ui-library';
|
||||
import { Flex, Input, Text } from '@invoke-ai/ui-library';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useStateImperative } from 'common/hooks/useStateImperative';
|
||||
import { fixedForwardRef } from 'common/util/fixedForwardRef';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import type { ChangeEvent, PropsWithChildren } from 'react';
|
||||
import {
|
||||
createContext,
|
||||
useCallback,
|
||||
@@ -26,7 +26,7 @@ export type Group<T extends object, U = any> = {
|
||||
options: T[];
|
||||
};
|
||||
|
||||
const isGroup = <T extends object>(option: T | Group<T>): option is Group<T> => {
|
||||
export const isGroup = <T extends object>(option: T | Group<T>): option is Group<T> => {
|
||||
return option ? 'options' in option && Array.isArray(option.options) : false;
|
||||
};
|
||||
|
||||
@@ -43,10 +43,19 @@ const DefaultOptionComponent = typedMemo(<T extends object>({ option }: { option
|
||||
});
|
||||
DefaultOptionComponent.displayName = 'DefaultOptionComponent';
|
||||
|
||||
const DefaultGroupHeaderComponent = typedMemo(<T extends object>({ group }: { group: Group<T> }) => {
|
||||
return <Text fontWeight="bold">{group.id}</Text>;
|
||||
});
|
||||
DefaultGroupHeaderComponent.displayName = 'DefaultGroupHeaderComponent';
|
||||
const DefaultGroupComponent = typedMemo(
|
||||
<T extends object>({ group, children }: PropsWithChildren<{ group: Group<T> }>) => {
|
||||
return (
|
||||
<Flex flexDir="column" gap={2} w="full">
|
||||
<Text fontWeight="bold">{group.id}</Text>
|
||||
<Flex flexDir="column" gap={1} w="full">
|
||||
{children}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
DefaultGroupComponent.displayName = 'DefaultGroupComponent';
|
||||
|
||||
const DefaultNoOptionsFallbackComponent = typedMemo(() => {
|
||||
const { t } = useTranslation();
|
||||
@@ -68,7 +77,7 @@ const DefaultNoMatchesFallbackComponent = typedMemo(() => {
|
||||
});
|
||||
DefaultNoMatchesFallbackComponent.displayName = 'DefaultNoMatchesFallbackComponent';
|
||||
|
||||
export type PickerProps<T extends object> = {
|
||||
export type PickerProps<T extends object, U> = {
|
||||
options: (T | Group<T>)[];
|
||||
getOptionId: (option: T) => string;
|
||||
isMatch: (option: T, searchTerm: string) => boolean;
|
||||
@@ -80,11 +89,18 @@ export type PickerProps<T extends object> = {
|
||||
SearchBarComponent?: ReturnType<typeof fixedForwardRef<HTMLInputElement, InputProps>>;
|
||||
NoOptionsFallbackComponent?: React.ComponentType;
|
||||
NoMatchesFallbackComponent?: React.ComponentType;
|
||||
OptionComponent?: React.ComponentType<{ option: T }>;
|
||||
GroupHeaderComponent?: React.ComponentType<{ group: Group<T> }>;
|
||||
OptionComponent?: React.ComponentType<
|
||||
{
|
||||
option: T;
|
||||
} & BoxProps
|
||||
>;
|
||||
GroupComponent?: React.ComponentType<
|
||||
PropsWithChildren<{ group: Group<T, U>; activeOptionId: string | undefined } & BoxProps>
|
||||
>;
|
||||
};
|
||||
|
||||
type PickerContextState<T extends object> = {
|
||||
type PickerContextState<T extends object, U> = {
|
||||
options: (T | Group<T>)[];
|
||||
getOptionId: (option: T) => string;
|
||||
isMatch: (option: T, searchTerm: string) => boolean;
|
||||
getIsDisabled?: (option: T) => boolean;
|
||||
@@ -93,13 +109,15 @@ type PickerContextState<T extends object> = {
|
||||
SearchBarComponent: ReturnType<typeof fixedForwardRef<HTMLInputElement, InputProps>>;
|
||||
NoOptionsFallbackComponent: React.ComponentType;
|
||||
NoMatchesFallbackComponent: React.ComponentType;
|
||||
OptionComponent: React.ComponentType<{ option: T }>;
|
||||
GroupHeaderComponent: React.ComponentType<{ group: Group<T> }>;
|
||||
OptionComponent: React.ComponentType<{ option: T } & BoxProps>;
|
||||
GroupComponent: React.ComponentType<
|
||||
PropsWithChildren<{ group: Group<T, U>; activeOptionId: string | undefined } & BoxProps>
|
||||
>;
|
||||
};
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const PickerContext = createContext<PickerContextState<any> | null>(null);
|
||||
const usePickerContext = <T extends object>(): PickerContextState<T> => {
|
||||
const PickerContext = createContext<PickerContextState<any, any> | null>(null);
|
||||
export const usePickerContext = <T extends object, U>(): PickerContextState<T, U> => {
|
||||
const context = useContext(PickerContext);
|
||||
assert(context !== null, 'usePickerContext must be used within a PickerProvider');
|
||||
return context;
|
||||
@@ -176,7 +194,7 @@ const flattenOptions = <T extends object>(options: (T | Group<T>)[]): T[] => {
|
||||
return flattened;
|
||||
};
|
||||
|
||||
export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
export const Picker = typedMemo(<T extends object, U>(props: PickerProps<T, U>) => {
|
||||
const {
|
||||
getOptionId,
|
||||
options,
|
||||
@@ -190,14 +208,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
NoMatchesFallbackComponent = DefaultNoMatchesFallbackComponent,
|
||||
NoOptionsFallbackComponent = DefaultNoOptionsFallbackComponent,
|
||||
OptionComponent = DefaultOptionComponent,
|
||||
GroupHeaderComponent = DefaultGroupHeaderComponent,
|
||||
GroupComponent = DefaultGroupComponent,
|
||||
} = props;
|
||||
const [activeOptionId, setActiveOptionId, getActiveOptionId] = useStateImperative(() =>
|
||||
getFirstOptionId(options, getOptionId)
|
||||
);
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const [filteredOptions, setFilteredOptions] = useState<(T | Group<T>)[]>(options);
|
||||
const [filteredOptions, setFilteredOptions] = useState<(T | Group<T, U>)[]>(options);
|
||||
const flattenedOptions = useMemo(() => flattenOptions(options), [options]);
|
||||
const flattenedFilteredOptions = useMemo(() => flattenOptions(filteredOptions), [filteredOptions]);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
@@ -213,7 +231,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
setActiveOptionId(getFirstOptionId(options, getOptionId));
|
||||
} else {
|
||||
const lowercasedSearchTerm = searchTerm.toLowerCase();
|
||||
const filtered: (T | Group<T>)[] = [];
|
||||
const filtered: (T | Group<T, U>)[] = [];
|
||||
for (const item of props.options) {
|
||||
if (isGroup(item)) {
|
||||
const filteredItems = item.options.filter((item) => isMatch(item, lowercasedSearchTerm));
|
||||
@@ -346,6 +364,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
const ctx = useMemo(
|
||||
() =>
|
||||
({
|
||||
options,
|
||||
getOptionId,
|
||||
isMatch,
|
||||
getIsDisabled,
|
||||
@@ -355,9 +374,10 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
NoOptionsFallbackComponent,
|
||||
NoMatchesFallbackComponent,
|
||||
OptionComponent,
|
||||
GroupHeaderComponent,
|
||||
}) satisfies PickerContextState<T>,
|
||||
GroupComponent,
|
||||
}) satisfies PickerContextState<T, U>,
|
||||
[
|
||||
options,
|
||||
getOptionId,
|
||||
isMatch,
|
||||
getIsDisabled,
|
||||
@@ -367,7 +387,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
NoOptionsFallbackComponent,
|
||||
NoMatchesFallbackComponent,
|
||||
OptionComponent,
|
||||
GroupHeaderComponent,
|
||||
GroupComponent,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -385,7 +405,6 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
onKeyDown={onKeyDown}
|
||||
>
|
||||
<SearchBarComponent ref={inputRef} value={searchTerm} onChange={onChangeSearchTerm} />
|
||||
<Divider />
|
||||
<Flex tabIndex={-1} w="full" flexGrow={1}>
|
||||
{flattenedOptions.length === 0 && <NoOptionsFallbackComponent />}
|
||||
{flattenedOptions.length > 0 && flattenedFilteredOptions.length === 0 && <NoMatchesFallbackComponent />}
|
||||
@@ -413,16 +432,16 @@ const DefaultPickerSearchBarComponent = typedMemo(
|
||||
DefaultPickerSearchBarComponent.displayName = 'DefaultPickerSearchBarComponent';
|
||||
|
||||
const PickerList = typedMemo(
|
||||
<T extends object>({
|
||||
<T extends object, U>({
|
||||
items,
|
||||
activeOptionId,
|
||||
selectedItemId,
|
||||
}: {
|
||||
items: (T | Group<T>)[];
|
||||
items: (T | Group<T, U>)[];
|
||||
activeOptionId: string | undefined;
|
||||
selectedItemId: string | undefined;
|
||||
}) => {
|
||||
const { getOptionId, getIsDisabled } = usePickerContext<T>();
|
||||
const { getOptionId, getIsDisabled } = usePickerContext<T, U>();
|
||||
|
||||
if (items.length === 0) {
|
||||
return (
|
||||
@@ -470,63 +489,47 @@ const PickerList = typedMemo(
|
||||
PickerList.displayName = 'PickerList';
|
||||
|
||||
const PickerOptionGroup = typedMemo(
|
||||
<T extends object>({
|
||||
<T extends object, U>({
|
||||
group,
|
||||
activeOptionId,
|
||||
selectedItemId,
|
||||
}: {
|
||||
group: Group<T>;
|
||||
group: Group<T, U>;
|
||||
activeOptionId: string | undefined;
|
||||
selectedItemId: string | undefined;
|
||||
}) => {
|
||||
const { getOptionId, GroupHeaderComponent, getIsDisabled } = usePickerContext<T>();
|
||||
const { getOptionId, GroupComponent, getIsDisabled } = usePickerContext<T, U>();
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2} w="full">
|
||||
<GroupHeaderComponent group={group} />
|
||||
<Flex flexDir="column" gap={1} w="full">
|
||||
{group.options.map((item) => {
|
||||
const id = getOptionId(item);
|
||||
return (
|
||||
<PickerOption
|
||||
key={id}
|
||||
id={id}
|
||||
option={item}
|
||||
isActive={id === activeOptionId}
|
||||
isSelected={id === selectedItemId}
|
||||
isDisabled={getIsDisabled?.(item) ?? false}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</Flex>
|
||||
</Flex>
|
||||
<GroupComponent group={group} activeOptionId={activeOptionId}>
|
||||
{group.options.map((item) => {
|
||||
const id = getOptionId(item);
|
||||
return (
|
||||
<PickerOption
|
||||
key={id}
|
||||
id={id}
|
||||
option={item}
|
||||
isActive={id === activeOptionId}
|
||||
isSelected={id === selectedItemId}
|
||||
isDisabled={getIsDisabled?.(item) ?? false}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</GroupComponent>
|
||||
);
|
||||
}
|
||||
);
|
||||
PickerOptionGroup.displayName = 'PickerOptionGroup';
|
||||
|
||||
const itemSx: SystemStyleObject = {
|
||||
display: 'flex',
|
||||
flexDir: 'column',
|
||||
p: 2,
|
||||
cursor: 'pointer',
|
||||
borderRadius: 'base',
|
||||
'&[data-selected="true"]': {
|
||||
borderColor: 'invokeBlue.300',
|
||||
borderWidth: 1,
|
||||
},
|
||||
'&[data-active="true"]': {
|
||||
bg: 'base.700',
|
||||
},
|
||||
'&[data-disabled="true"]': {
|
||||
cursor: 'not-allowed',
|
||||
opacity: 0.5,
|
||||
},
|
||||
};
|
||||
|
||||
const PickerOption = typedMemo(
|
||||
<T extends object>(props: { id: string; option: T; isActive: boolean; isSelected: boolean; isDisabled: boolean }) => {
|
||||
const { OptionComponent, setActiveOptionId, onSelectById } = usePickerContext<T>();
|
||||
<T extends object, U>(props: {
|
||||
id: string;
|
||||
option: T;
|
||||
isActive: boolean;
|
||||
isSelected: boolean;
|
||||
isDisabled: boolean;
|
||||
}) => {
|
||||
const { OptionComponent, setActiveOptionId, onSelectById } = usePickerContext<T, U>();
|
||||
const { id, option, isActive, isDisabled, isSelected } = props;
|
||||
const onPointerMove = useCallback(() => {
|
||||
setActiveOptionId(id);
|
||||
@@ -535,18 +538,16 @@ const PickerOption = typedMemo(
|
||||
onSelectById(id);
|
||||
}, [id, onSelectById]);
|
||||
return (
|
||||
<Box
|
||||
role="option"
|
||||
sx={itemSx}
|
||||
<OptionComponent
|
||||
tabIndex={-1}
|
||||
option={option}
|
||||
id={id}
|
||||
data-disabled={isDisabled}
|
||||
data-selected={isSelected}
|
||||
data-active={isActive}
|
||||
onPointerMove={isDisabled ? undefined : onPointerMove}
|
||||
onClick={isDisabled ? undefined : onClick}
|
||||
>
|
||||
<OptionComponent option={option} />
|
||||
</Box>
|
||||
/>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
@@ -7,7 +7,7 @@ type Props = {
|
||||
base: BaseModelType;
|
||||
};
|
||||
|
||||
const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
any: 'base',
|
||||
'sd-1': 'green',
|
||||
'sd-2': 'teal',
|
||||
@@ -15,7 +15,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
sdxl: 'invokeBlue',
|
||||
'sdxl-refiner': 'invokeBlue',
|
||||
flux: 'gold',
|
||||
cogview4: 'orange',
|
||||
cogview4: 'red',
|
||||
};
|
||||
|
||||
const ModelBaseBadge = ({ base }: Props) => {
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import type { FormLabelProps, InputProps } from '@invoke-ai/ui-library';
|
||||
import type { BoxProps, FormLabelProps, InputProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Collapse,
|
||||
Expander,
|
||||
Flex,
|
||||
FormControlGroup,
|
||||
FormLabel,
|
||||
Icon,
|
||||
Input,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
@@ -24,13 +26,14 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf
|
||||
import type { Group, ImperativeModelPickerHandle } from 'common/components/Picker/Picker';
|
||||
import { getRegex, Picker } from 'common/components/Picker/Picker';
|
||||
import { useDisclosure } from 'common/hooks/useBoolean';
|
||||
import { useStateImperative } from 'common/hooks/useStateImperative';
|
||||
import { fixedForwardRef } from 'common/util/fixedForwardRef';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { selectIsCogView4, selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
|
||||
import { LoRAList } from 'features/lora/components/LoRAList';
|
||||
import LoRASelect from 'features/lora/components/LoRASelect';
|
||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import { BASE_COLOR_MAP } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import ModelImage from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelImage';
|
||||
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
import ParamGuidance from 'features/parameters/components/Core/ParamGuidance';
|
||||
@@ -41,11 +44,13 @@ import { UseDefaultSettingsButton } from 'features/parameters/components/MainMod
|
||||
import ParamUpscaleCFGScale from 'features/parameters/components/Upscale/ParamUpscaleCFGScale';
|
||||
import ParamUpscaleScheduler from 'features/parameters/components/Upscale/ParamUpscaleScheduler';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { filesize } from 'filesize';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -121,22 +126,26 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
GenerationSettingsAccordion.displayName = 'GenerationSettingsAccordion';
|
||||
|
||||
const getOptionId = (modelConfig: AnyModelConfig) => modelConfig.key;
|
||||
const getIsDisabled = (modelConfig: AnyModelConfig) => {
|
||||
return modelConfig.base === 'flux';
|
||||
|
||||
type GroupData = {
|
||||
base: BaseModelType;
|
||||
description: string;
|
||||
};
|
||||
|
||||
const MainModelPicker = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const [modelConfigs] = useMainModels();
|
||||
const grouped = useMemo<Group<AnyModelConfig, { name: string; description: string }>[]>(() => {
|
||||
const groups: { [base in BaseModelType]?: Group<AnyModelConfig, { name: string; description: string }> } = {};
|
||||
const grouped = useMemo<Group<AnyModelConfig, GroupData>[]>(() => {
|
||||
const groups: {
|
||||
[base in BaseModelType]?: Group<AnyModelConfig, GroupData>;
|
||||
} = {};
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
let group = groups[modelConfig.base];
|
||||
if (!group) {
|
||||
group = {
|
||||
id: modelConfig.base,
|
||||
data: { name: modelConfig.base, description: `A brief description of ${modelConfig.base} models.` },
|
||||
data: { base: modelConfig.base, description: `A brief description of ${modelConfig.base} models.` },
|
||||
options: [],
|
||||
};
|
||||
groups[modelConfig.base] = group;
|
||||
@@ -145,7 +154,7 @@ const MainModelPicker = memo(() => {
|
||||
group.options.push(modelConfig);
|
||||
}
|
||||
|
||||
const sortedGroups: Group<AnyModelConfig, { name: string; description: string }>[] = [];
|
||||
const sortedGroups: Group<AnyModelConfig, GroupData>[] = [];
|
||||
|
||||
if (groups['flux']) {
|
||||
sortedGroups.push(groups['flux']);
|
||||
@@ -210,11 +219,11 @@ const MainModelPicker = memo(() => {
|
||||
<NavigateToModelManagerButton />
|
||||
<UseDefaultSettingsButton />
|
||||
</Flex>
|
||||
<Portal>
|
||||
<Portal appendToParentPortal={false}>
|
||||
<PopoverContent p={0} w={448} h={512}>
|
||||
<PopoverArrow />
|
||||
<PopoverBody p={0} w="full" h="full">
|
||||
<Picker<AnyModelConfig>
|
||||
<Picker<AnyModelConfig, GroupData>
|
||||
handleRef={pickerRef}
|
||||
options={grouped}
|
||||
getOptionId={getOptionId}
|
||||
@@ -222,8 +231,8 @@ const MainModelPicker = memo(() => {
|
||||
selectedItem={modelConfig}
|
||||
// getIsDisabled={getIsDisabled}
|
||||
isMatch={isMatch}
|
||||
OptionComponent={PickerItemComponent}
|
||||
GroupHeaderComponent={PickerGroupHeaderComponent}
|
||||
OptionComponent={PickerOptionComponent}
|
||||
GroupComponent={PickerGroupComponent}
|
||||
SearchBarComponent={SearchBarComponent}
|
||||
/>
|
||||
</PopoverBody>
|
||||
@@ -238,34 +247,127 @@ const SearchBarComponent = typedMemo(
|
||||
fixedForwardRef<HTMLInputElement, InputProps>((props, ref) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Input ref={ref} {...props} placeholder={t('common.search')} />
|
||||
<NavigateToModelManagerButton />
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Input ref={ref} {...props} placeholder={t('modelManager.filterModels')} />
|
||||
<NavigateToModelManagerButton />
|
||||
</Flex>
|
||||
<Flex gap={2} alignItems="center"></Flex>
|
||||
</Flex>
|
||||
);
|
||||
})
|
||||
);
|
||||
SearchBarComponent.displayName = 'SearchBarComponent';
|
||||
|
||||
const PickerGroupHeaderComponent = memo(
|
||||
({ group }: { group: Group<AnyModelConfig, { name: string; description: string }> }) => {
|
||||
const toggleButtonSx = {
|
||||
"&[data-expanded='true']": {
|
||||
transform: 'rotate(180deg)',
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
const PickerGroupComponent = memo(
|
||||
({
|
||||
group,
|
||||
activeOptionId,
|
||||
children,
|
||||
}: PropsWithChildren<{ group: Group<AnyModelConfig, GroupData>; activeOptionId: string | undefined }>) => {
|
||||
const [isOpen, setIsOpen, getIsOpen] = useStateImperative(true);
|
||||
useEffect(() => {
|
||||
if (group.options.some((option) => option.key === activeOptionId) && !getIsOpen()) {
|
||||
setIsOpen(true);
|
||||
}
|
||||
}, [activeOptionId, getIsOpen, group.options, setIsOpen]);
|
||||
const toggle = useCallback(() => {
|
||||
setIsOpen((prev) => !prev);
|
||||
}, [setIsOpen]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" ps={8}>
|
||||
<Text fontSize="sm" fontWeight="semibold">
|
||||
{`${group.data.name} (${group.options.length} models)`}
|
||||
</Text>
|
||||
<Text color="base.200" fontStyle="italic">
|
||||
{group.data.description}
|
||||
</Text>
|
||||
<Flex
|
||||
flexDir="column"
|
||||
w="full"
|
||||
borderLeftColor={`${BASE_COLOR_MAP[group.data.base]}.300`}
|
||||
borderLeftWidth={4}
|
||||
ps={2}
|
||||
>
|
||||
<GroupHeader group={group} isOpen={isOpen} toggle={toggle} />
|
||||
<Collapse in={isOpen} animateOpacity>
|
||||
<Flex flexDir="column" gap={1} w="full" pb={2}>
|
||||
{children}
|
||||
</Flex>
|
||||
</Collapse>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
PickerGroupHeaderComponent.displayName = 'PickerGroupHeaderComponent';
|
||||
PickerGroupComponent.displayName = 'PickerGroupComponent';
|
||||
|
||||
export const PickerItemComponent = typedMemo(({ option }: { option: AnyModelConfig }) => {
|
||||
const GroupHeader = memo(
|
||||
({
|
||||
group,
|
||||
isOpen,
|
||||
toggle,
|
||||
...rest
|
||||
}: { group: Group<AnyModelConfig, GroupData>; isOpen: boolean; toggle: () => void } & BoxProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex
|
||||
{...rest}
|
||||
role="button"
|
||||
alignItems="center"
|
||||
px={2}
|
||||
onClick={toggle}
|
||||
userSelect="none"
|
||||
position="sticky"
|
||||
top={0}
|
||||
bg="base.800"
|
||||
pb={2}
|
||||
>
|
||||
<Flex flexDir="column" flex={1}>
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Text fontSize="sm" fontWeight="semibold" color={`${BASE_COLOR_MAP[group.data.base]}.300`}>
|
||||
{MODEL_TYPE_SHORT_MAP[group.data.base]}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.300" noOfLines={1}>
|
||||
{t('common.model_withCount', { count: group.options.length })}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Text color="base.200" fontStyle="italic">
|
||||
{group.data.description}
|
||||
</Text>
|
||||
<Spacer />
|
||||
</Flex>
|
||||
<Icon color="base.300" as={PiCaretDownBold} sx={toggleButtonSx} data-expanded={isOpen} boxSize={4} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
GroupHeader.displayName = 'GroupHeader';
|
||||
|
||||
const optionSx: SystemStyleObject = {
|
||||
p: 2,
|
||||
gap: 2,
|
||||
cursor: 'pointer',
|
||||
borderRadius: 'base',
|
||||
'&[data-selected="true"]': {
|
||||
bg: 'base.700',
|
||||
'&[data-active="true"]': {
|
||||
bg: 'base.650',
|
||||
},
|
||||
},
|
||||
'&[data-active="true"]': {
|
||||
bg: 'base.750',
|
||||
},
|
||||
'&[data-disabled="true"]': {
|
||||
cursor: 'not-allowed',
|
||||
opacity: 0.5,
|
||||
},
|
||||
scrollMarginTop: '42px', // magic number, this is the height of the header
|
||||
};
|
||||
|
||||
export const PickerOptionComponent = typedMemo(({ option, ...rest }: { option: AnyModelConfig } & BoxProps) => {
|
||||
return (
|
||||
<Flex tabIndex={-1} gap={2}>
|
||||
<Flex {...rest} sx={optionSx}>
|
||||
<ModelImage image_url={option.cover_image} />
|
||||
<Flex flexDir="column" gap={2} flex={1}>
|
||||
<Flex gap={2} alignItems="center">
|
||||
@@ -276,14 +378,13 @@ export const PickerItemComponent = typedMemo(({ option }: { option: AnyModelConf
|
||||
<Text variant="subtext" fontStyle="italic" noOfLines={1} flexShrink={0} overflow="visible">
|
||||
{filesize(option.file_size)}
|
||||
</Text>
|
||||
<ModelBaseBadge base={option.base} />
|
||||
</Flex>
|
||||
{option.description && <Text color="base.200">{option.description}</Text>}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
PickerItemComponent.displayName = 'PickerItemComponent';
|
||||
PickerOptionComponent.displayName = 'PickerItemComponent';
|
||||
|
||||
const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
|
||||
'sd-1': ['sd1', 'sd1.4', 'sd1.5', 'sd-1'],
|
||||
|
||||
Reference in New Issue
Block a user