From 015dc3ac0d874057b684816a9494c5f662056450 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 17 Apr 2025 19:49:51 +1000 Subject: [PATCH] feat(ui): wip model picker --- invokeai/frontend/web/public/locales/en.json | 3 + .../src/common/components/Picker/Picker.tsx | 151 ++++++++-------- .../ModelManagerPanel/ModelBaseBadge.tsx | 4 +- .../GenerationSettingsAccordion.tsx | 161 ++++++++++++++---- 4 files changed, 212 insertions(+), 107 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1a5d9af0c3..d472290749 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -118,6 +118,8 @@ "error": "Error", "error_withCount_one": "{{count}} error", "error_withCount_other": "{{count}} errors", + "model_withCount_one": "{{count}} model", + "model_withCount_other": "{{count}} models", "file": "File", "folder": "Folder", "format": "format", @@ -768,6 +770,7 @@ "description": "Description", "edit": "Edit", "fileSize": "File Size", + "filterModels": "Filter models", "fluxRedux": "FLUX Redux", "height": "Height", "huggingFace": "HuggingFace", diff --git a/invokeai/frontend/web/src/common/components/Picker/Picker.tsx b/invokeai/frontend/web/src/common/components/Picker/Picker.tsx index c5b15659e7..5ead2f88e4 100644 --- a/invokeai/frontend/web/src/common/components/Picker/Picker.tsx +++ b/invokeai/frontend/web/src/common/components/Picker/Picker.tsx @@ -1,11 +1,11 @@ -import type { InputProps, SystemStyleObject } from '@invoke-ai/ui-library'; -import { Box, Divider, Flex, Input, Text } from '@invoke-ai/ui-library'; +import type { BoxProps, InputProps } from '@invoke-ai/ui-library'; +import { Flex, Input, Text } from '@invoke-ai/ui-library'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import { useStateImperative } from 'common/hooks/useStateImperative'; import { fixedForwardRef } from 'common/util/fixedForwardRef'; import { typedMemo } from 'common/util/typedMemo'; -import type { ChangeEvent } from 'react'; +import type { ChangeEvent, PropsWithChildren } from 'react'; import { createContext, useCallback, @@ -26,7 +26,7 @@ export type Group = { options: T[]; }; -const isGroup = (option: T | Group): option is Group => { +export const isGroup = (option: T | Group): option is Group => { return option ? 'options' in option && Array.isArray(option.options) : false; }; @@ -43,10 +43,19 @@ const DefaultOptionComponent = typedMemo(({ option }: { option }); DefaultOptionComponent.displayName = 'DefaultOptionComponent'; -const DefaultGroupHeaderComponent = typedMemo(({ group }: { group: Group }) => { - return {group.id}; -}); -DefaultGroupHeaderComponent.displayName = 'DefaultGroupHeaderComponent'; +const DefaultGroupComponent = typedMemo( + ({ group, children }: PropsWithChildren<{ group: Group }>) => { + return ( + + {group.id} + + {children} + + + ); + } +); +DefaultGroupComponent.displayName = 'DefaultGroupComponent'; const DefaultNoOptionsFallbackComponent = typedMemo(() => { const { t } = useTranslation(); @@ -68,7 +77,7 @@ const DefaultNoMatchesFallbackComponent = typedMemo(() => { }); DefaultNoMatchesFallbackComponent.displayName = 'DefaultNoMatchesFallbackComponent'; -export type PickerProps = { +export type PickerProps = { options: (T | Group)[]; getOptionId: (option: T) => string; isMatch: (option: T, searchTerm: string) => boolean; @@ -80,11 +89,18 @@ export type PickerProps = { SearchBarComponent?: ReturnType>; NoOptionsFallbackComponent?: React.ComponentType; NoMatchesFallbackComponent?: React.ComponentType; - OptionComponent?: React.ComponentType<{ option: T }>; - GroupHeaderComponent?: React.ComponentType<{ group: Group }>; + OptionComponent?: React.ComponentType< + { + option: T; + } & BoxProps + >; + GroupComponent?: React.ComponentType< + PropsWithChildren<{ group: Group; activeOptionId: string | undefined } & BoxProps> + >; }; -type PickerContextState = { +type PickerContextState = { + options: (T | Group)[]; getOptionId: (option: T) => string; isMatch: (option: T, searchTerm: string) => boolean; getIsDisabled?: (option: T) => boolean; @@ -93,13 +109,15 @@ type PickerContextState = { SearchBarComponent: ReturnType>; NoOptionsFallbackComponent: React.ComponentType; NoMatchesFallbackComponent: React.ComponentType; - OptionComponent: React.ComponentType<{ option: T }>; - GroupHeaderComponent: React.ComponentType<{ group: Group }>; + OptionComponent: React.ComponentType<{ option: T } & BoxProps>; + GroupComponent: React.ComponentType< + PropsWithChildren<{ group: Group; activeOptionId: string | undefined } & BoxProps> + >; }; /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const PickerContext = createContext | null>(null); -const usePickerContext = (): PickerContextState => { +const PickerContext = createContext | null>(null); +export const usePickerContext = (): PickerContextState => { const context = useContext(PickerContext); assert(context !== null, 'usePickerContext must be used within a PickerProvider'); return context; @@ -176,7 +194,7 @@ const flattenOptions = (options: (T | Group)[]): T[] => { return flattened; }; -export const Picker = typedMemo((props: PickerProps) => { +export const Picker = typedMemo((props: PickerProps) => { const { getOptionId, options, @@ -190,14 +208,14 @@ export const Picker = typedMemo((props: PickerProps) => { NoMatchesFallbackComponent = DefaultNoMatchesFallbackComponent, NoOptionsFallbackComponent = DefaultNoOptionsFallbackComponent, OptionComponent = DefaultOptionComponent, - GroupHeaderComponent = DefaultGroupHeaderComponent, + GroupComponent = DefaultGroupComponent, } = props; const [activeOptionId, setActiveOptionId, getActiveOptionId] = useStateImperative(() => getFirstOptionId(options, getOptionId) ); const rootRef = useRef(null); const inputRef = useRef(null); - const [filteredOptions, setFilteredOptions] = useState<(T | Group)[]>(options); + const [filteredOptions, setFilteredOptions] = useState<(T | Group)[]>(options); const flattenedOptions = useMemo(() => flattenOptions(options), [options]); const flattenedFilteredOptions = useMemo(() => flattenOptions(filteredOptions), [filteredOptions]); const [searchTerm, setSearchTerm] = useState(''); @@ -213,7 +231,7 @@ export const Picker = typedMemo((props: PickerProps) => { setActiveOptionId(getFirstOptionId(options, getOptionId)); } else { const lowercasedSearchTerm = searchTerm.toLowerCase(); - const filtered: (T | Group)[] = []; + const filtered: (T | Group)[] = []; for (const item of props.options) { if (isGroup(item)) { const filteredItems = item.options.filter((item) => isMatch(item, lowercasedSearchTerm)); @@ -346,6 +364,7 @@ export const Picker = typedMemo((props: PickerProps) => { const ctx = useMemo( () => ({ + options, getOptionId, isMatch, getIsDisabled, @@ -355,9 +374,10 @@ export const Picker = typedMemo((props: PickerProps) => { NoOptionsFallbackComponent, NoMatchesFallbackComponent, OptionComponent, - GroupHeaderComponent, - }) satisfies PickerContextState, + GroupComponent, + }) satisfies PickerContextState, [ + options, getOptionId, isMatch, getIsDisabled, @@ -367,7 +387,7 @@ export const Picker = typedMemo((props: PickerProps) => { NoOptionsFallbackComponent, NoMatchesFallbackComponent, OptionComponent, - GroupHeaderComponent, + GroupComponent, ] ); @@ -385,7 +405,6 @@ export const Picker = typedMemo((props: PickerProps) => { onKeyDown={onKeyDown} > - {flattenedOptions.length === 0 && } {flattenedOptions.length > 0 && flattenedFilteredOptions.length === 0 && } @@ -413,16 +432,16 @@ const DefaultPickerSearchBarComponent = typedMemo( DefaultPickerSearchBarComponent.displayName = 'DefaultPickerSearchBarComponent'; const PickerList = typedMemo( - ({ + ({ items, activeOptionId, selectedItemId, }: { - items: (T | Group)[]; + items: (T | Group)[]; activeOptionId: string | undefined; selectedItemId: string | undefined; }) => { - const { getOptionId, getIsDisabled } = usePickerContext(); + const { getOptionId, getIsDisabled } = usePickerContext(); if (items.length === 0) { return ( @@ -470,63 +489,47 @@ const PickerList = typedMemo( PickerList.displayName = 'PickerList'; const PickerOptionGroup = typedMemo( - ({ + ({ group, activeOptionId, selectedItemId, }: { - group: Group; + group: Group; activeOptionId: string | undefined; selectedItemId: string | undefined; }) => { - const { getOptionId, GroupHeaderComponent, getIsDisabled } = usePickerContext(); + const { getOptionId, GroupComponent, getIsDisabled } = usePickerContext(); return ( - - - - {group.options.map((item) => { - const id = getOptionId(item); - return ( - - ); - })} - - + + {group.options.map((item) => { + const id = getOptionId(item); + return ( + + ); + })} + ); } ); PickerOptionGroup.displayName = 'PickerOptionGroup'; -const itemSx: SystemStyleObject = { - display: 'flex', - flexDir: 'column', - p: 2, - cursor: 'pointer', - borderRadius: 'base', - '&[data-selected="true"]': { - borderColor: 'invokeBlue.300', - borderWidth: 1, - }, - '&[data-active="true"]': { - bg: 'base.700', - }, - '&[data-disabled="true"]': { - cursor: 'not-allowed', - opacity: 0.5, - }, -}; - const PickerOption = typedMemo( - (props: { id: string; option: T; isActive: boolean; isSelected: boolean; isDisabled: boolean }) => { - const { OptionComponent, setActiveOptionId, onSelectById } = usePickerContext(); + (props: { + id: string; + option: T; + isActive: boolean; + isSelected: boolean; + isDisabled: boolean; + }) => { + const { OptionComponent, setActiveOptionId, onSelectById } = usePickerContext(); const { id, option, isActive, isDisabled, isSelected } = props; const onPointerMove = useCallback(() => { setActiveOptionId(id); @@ -535,18 +538,16 @@ const PickerOption = typedMemo( onSelectById(id); }, [id, onSelectById]); return ( - - - + /> ); } ); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx index d5e25d12ab..fffef0601b 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx @@ -7,7 +7,7 @@ type Props = { base: BaseModelType; }; -const BASE_COLOR_MAP: Record = { +export const BASE_COLOR_MAP: Record = { any: 'base', 'sd-1': 'green', 'sd-2': 'teal', @@ -15,7 +15,7 @@ const BASE_COLOR_MAP: Record = { sdxl: 'invokeBlue', 'sdxl-refiner': 'invokeBlue', flux: 'gold', - cogview4: 'orange', + cogview4: 'red', }; const ModelBaseBadge = ({ base }: Props) => { diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index 89f5a3ee5a..07f3227586 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -1,11 +1,13 @@ -import type { FormLabelProps, InputProps } from '@invoke-ai/ui-library'; +import type { BoxProps, FormLabelProps, InputProps, SystemStyleObject } from '@invoke-ai/ui-library'; import { Box, Button, + Collapse, Expander, Flex, FormControlGroup, FormLabel, + Icon, Input, Popover, PopoverArrow, @@ -24,13 +26,14 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf import type { Group, ImperativeModelPickerHandle } from 'common/components/Picker/Picker'; import { getRegex, Picker } from 'common/components/Picker/Picker'; import { useDisclosure } from 'common/hooks/useBoolean'; +import { useStateImperative } from 'common/hooks/useStateImperative'; import { fixedForwardRef } from 'common/util/fixedForwardRef'; import { typedMemo } from 'common/util/typedMemo'; import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; import { selectIsCogView4, selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice'; import { LoRAList } from 'features/lora/components/LoRAList'; import LoRASelect from 'features/lora/components/LoRASelect'; -import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge'; +import { BASE_COLOR_MAP } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge'; import ModelImage from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelImage'; import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale'; import ParamGuidance from 'features/parameters/components/Core/ParamGuidance'; @@ -41,11 +44,13 @@ import { UseDefaultSettingsButton } from 'features/parameters/components/MainMod import ParamUpscaleCFGScale from 'features/parameters/components/Upscale/ParamUpscaleCFGScale'; import ParamUpscaleScheduler from 'features/parameters/components/Upscale/ParamUpscaleScheduler'; import { modelSelected } from 'features/parameters/store/actions'; +import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants'; import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import { filesize } from 'filesize'; -import { memo, useCallback, useMemo, useRef } from 'react'; +import type { PropsWithChildren } from 'react'; +import { memo, useCallback, useEffect, useMemo, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import { PiCaretDownBold } from 'react-icons/pi'; import { useMainModels } from 'services/api/hooks/modelsByType'; @@ -121,22 +126,26 @@ export const GenerationSettingsAccordion = memo(() => { GenerationSettingsAccordion.displayName = 'GenerationSettingsAccordion'; const getOptionId = (modelConfig: AnyModelConfig) => modelConfig.key; -const getIsDisabled = (modelConfig: AnyModelConfig) => { - return modelConfig.base === 'flux'; + +type GroupData = { + base: BaseModelType; + description: string; }; const MainModelPicker = memo(() => { const { t } = useTranslation(); const [modelConfigs] = useMainModels(); - const grouped = useMemo[]>(() => { - const groups: { [base in BaseModelType]?: Group } = {}; + const grouped = useMemo[]>(() => { + const groups: { + [base in BaseModelType]?: Group; + } = {}; for (const modelConfig of modelConfigs) { let group = groups[modelConfig.base]; if (!group) { group = { id: modelConfig.base, - data: { name: modelConfig.base, description: `A brief description of ${modelConfig.base} models.` }, + data: { base: modelConfig.base, description: `A brief description of ${modelConfig.base} models.` }, options: [], }; groups[modelConfig.base] = group; @@ -145,7 +154,7 @@ const MainModelPicker = memo(() => { group.options.push(modelConfig); } - const sortedGroups: Group[] = []; + const sortedGroups: Group[] = []; if (groups['flux']) { sortedGroups.push(groups['flux']); @@ -210,11 +219,11 @@ const MainModelPicker = memo(() => { - + - + handleRef={pickerRef} options={grouped} getOptionId={getOptionId} @@ -222,8 +231,8 @@ const MainModelPicker = memo(() => { selectedItem={modelConfig} // getIsDisabled={getIsDisabled} isMatch={isMatch} - OptionComponent={PickerItemComponent} - GroupHeaderComponent={PickerGroupHeaderComponent} + OptionComponent={PickerOptionComponent} + GroupComponent={PickerGroupComponent} SearchBarComponent={SearchBarComponent} /> @@ -238,34 +247,127 @@ const SearchBarComponent = typedMemo( fixedForwardRef((props, ref) => { const { t } = useTranslation(); return ( - - - + + + + + + ); }) ); SearchBarComponent.displayName = 'SearchBarComponent'; -const PickerGroupHeaderComponent = memo( - ({ group }: { group: Group }) => { +const toggleButtonSx = { + "&[data-expanded='true']": { + transform: 'rotate(180deg)', + }, +} satisfies SystemStyleObject; + +const PickerGroupComponent = memo( + ({ + group, + activeOptionId, + children, + }: PropsWithChildren<{ group: Group; activeOptionId: string | undefined }>) => { + const [isOpen, setIsOpen, getIsOpen] = useStateImperative(true); + useEffect(() => { + if (group.options.some((option) => option.key === activeOptionId) && !getIsOpen()) { + setIsOpen(true); + } + }, [activeOptionId, getIsOpen, group.options, setIsOpen]); + const toggle = useCallback(() => { + setIsOpen((prev) => !prev); + }, [setIsOpen]); + return ( - - - {`${group.data.name} (${group.options.length} models)`} - - - {group.data.description} - + + + + + {children} + + ); } ); -PickerGroupHeaderComponent.displayName = 'PickerGroupHeaderComponent'; +PickerGroupComponent.displayName = 'PickerGroupComponent'; -export const PickerItemComponent = typedMemo(({ option }: { option: AnyModelConfig }) => { +const GroupHeader = memo( + ({ + group, + isOpen, + toggle, + ...rest + }: { group: Group; isOpen: boolean; toggle: () => void } & BoxProps) => { + const { t } = useTranslation(); + + return ( + + + + + {MODEL_TYPE_SHORT_MAP[group.data.base]} + + + {t('common.model_withCount', { count: group.options.length })} + + + + {group.data.description} + + + + + + ); + } +); +GroupHeader.displayName = 'GroupHeader'; + +const optionSx: SystemStyleObject = { + p: 2, + gap: 2, + cursor: 'pointer', + borderRadius: 'base', + '&[data-selected="true"]': { + bg: 'base.700', + '&[data-active="true"]': { + bg: 'base.650', + }, + }, + '&[data-active="true"]': { + bg: 'base.750', + }, + '&[data-disabled="true"]': { + cursor: 'not-allowed', + opacity: 0.5, + }, + scrollMarginTop: '42px', // magic number, this is the height of the header +}; + +export const PickerOptionComponent = typedMemo(({ option, ...rest }: { option: AnyModelConfig } & BoxProps) => { return ( - + @@ -276,14 +378,13 @@ export const PickerItemComponent = typedMemo(({ option }: { option: AnyModelConf {filesize(option.file_size)} - {option.description && {option.description}} ); }); -PickerItemComponent.displayName = 'PickerItemComponent'; +PickerOptionComponent.displayName = 'PickerItemComponent'; const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = { 'sd-1': ['sd1', 'sd1.4', 'sd1.5', 'sd-1'],