From d2e9237740a8e712d83dfee4e56172ea8bbecd52 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 14 Apr 2025 15:09:43 +1000 Subject: [PATCH] feat(ui): reworked model selection ui (WIP) --- .../common/components/ModelCmdk/ModelCmdk.tsx | 208 +++++++----------- .../ParametersPanelTextToImage.tsx | 11 +- .../src/services/api/hooks/modelsByType.ts | 11 + 3 files changed, 97 insertions(+), 133 deletions(-) diff --git a/invokeai/frontend/web/src/common/components/ModelCmdk/ModelCmdk.tsx b/invokeai/frontend/web/src/common/components/ModelCmdk/ModelCmdk.tsx index b00ca1d00a..3c69235ffd 100644 --- a/invokeai/frontend/web/src/common/components/ModelCmdk/ModelCmdk.tsx +++ b/invokeai/frontend/web/src/common/components/ModelCmdk/ModelCmdk.tsx @@ -1,24 +1,11 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; -import { - Box, - chakra, - Flex, - Input, - Modal, - ModalBody, - ModalContent, - ModalOverlay, - Spacer, - Text, -} from '@invoke-ai/ui-library'; +import { Box, chakra, Flex, Input, Modal, ModalBody, ModalContent, ModalOverlay, Text } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { EMPTY_ARRAY } from 'app/store/constants'; import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; -import { LRUCache } from 'lru-cache'; import { atom } from 'nanostores'; -import type { ChangeEvent } from 'react'; +import type { ChangeEvent, RefObject } from 'react'; import { memo, useCallback, useMemo, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; @@ -26,7 +13,7 @@ import type { AnyModelConfig } from 'services/api/types'; import { useDebounce } from 'use-debounce'; export type ModelCmdkOptions = { - filter?: (modelConfig: AnyModelConfig) => boolean; + modelConfigs: AnyModelConfig[]; onSelect: (modelConfig: AnyModelConfig) => void; onClose?: () => void; }; @@ -52,14 +39,8 @@ const closeModelCmdk = () => { $modelCmdkState.set({ isOpen: false }); }; -const regexCache = new LRUCache({ max: 1000 }); - -const getRegex = (searchTerm: string) => { - const cachedRegex = regexCache.get(searchTerm); - if (cachedRegex) { - return cachedRegex; - } - const regex = new RegExp( +const getRegex = (searchTerm: string) => + new RegExp( searchTerm .trim() .replace(/[-[\]{}()*+!<=:?./\\^$|#,]/g, '') @@ -67,24 +48,8 @@ const getRegex = (searchTerm: string) => { .join('.*'), 'gi' ); - regexCache.set(searchTerm, regex); - - return regex; -}; - -const filterCache = new LRUCache({ max: 1000 }); -const getFilter = (model: AnyModelConfig, searchTerm: string) => { - const key = `${model.key}-${searchTerm}`; - const cachedFilter = filterCache.get(key); - if (cachedFilter !== undefined) { - return cachedFilter; - } - - if (!searchTerm) { - filterCache.set(key, true); - return true; - } +const isMatch = (model: AnyModelConfig, searchTerm: string) => { const regex = getRegex(searchTerm); if ( @@ -99,33 +64,12 @@ const getFilter = (model: AnyModelConfig, searchTerm: string) => { model.format.includes(searchTerm) || regex.test(model.format) ) { - filterCache.set(key, true); return true; } - filterCache.set(key, false); return false; }; -const useEnrichedModelConfigs = () => { - const { data } = useGetModelConfigsQuery(); - const models = useMemo(() => { - if (!data || data.ids.length === 0) { - return EMPTY_ARRAY; - } - const allModels = modelConfigsAdapterSelectors.selectAll(data); - const enrichedModels: (AnyModelConfig & { searchableContent: string })[] = allModels.map((model) => { - const searchableContent = [model.name, model.base, model.type, model.format, model.description ?? ''].join(' '); - return { - ...model, - searchableContent, - }; - }); - return enrichedModels; - }, [data]); - return models; -}; - export const useModelCmdk = (options: ModelCmdkOptions) => { const onOpen = useCallback(() => { openModelCmdk(options); @@ -153,21 +97,8 @@ const cmdkRootSx: SystemStyleObject = { }; export const ModelCmdk = memo(() => { - const { t } = useTranslation(); const inputRef = useRef(null); - const [searchTerm, setSearchTerm] = useState(''); const state = useStore($modelCmdkState); - // Filtering the list is expensive - debounce the search term to avoid stutters - const [debouncedSearchTerm] = useDebounce(searchTerm, 300); - - const onChange = useCallback((e: ChangeEvent) => { - setSearchTerm(e.target.value); - }, []); - - const onClose = useCallback(() => { - closeModelCmdk(); - setSearchTerm(''); - }, []); const onSelect = useCallback( (model: AnyModelConfig) => { @@ -176,40 +107,25 @@ export const ModelCmdk = memo(() => { return; } state.onSelect(model); - onClose(); + closeModelCmdk(); }, - [onClose, state] + [state] ); return ( - + {state.isOpen && ( - - - - - - - - - - - - - - - + )} @@ -219,32 +135,55 @@ export const ModelCmdk = memo(() => { ModelCmdk.displayName = 'ModelCmdk'; -const ModelList = memo( +const ModelCommandRoot = memo( (props: { - searchTerm: string; - filter?: (model: AnyModelConfig) => boolean; + inputRef: RefObject; + modelConfigs: AnyModelConfig[]; onSelect: (model: AnyModelConfig) => void; }) => { + const { t } = useTranslation(); + + const { inputRef, modelConfigs, onSelect } = props; + const [searchTerm, setSearchTerm] = useState(''); + // Filtering the list is expensive - debounce the search term to avoid stutters + const [debouncedSearchTerm] = useDebounce(searchTerm, 300); + + const onChange = useCallback((e: ChangeEvent) => { + setSearchTerm(e.target.value); + }, []); + + return ( + + + + + + + + + + + + + + + + ); + } +); +ModelCommandRoot.displayName = 'ModelCommandRoot'; + +const ModelList = memo( + (props: { searchTerm: string; modelConfigs: AnyModelConfig[]; onSelect: (model: AnyModelConfig) => void }) => { const { data } = useGetModelConfigsQuery(); - const filteredModels = useMemo(() => { - if (!data || data.ids.length === 0) { - return EMPTY_ARRAY; - } - const allModels = modelConfigsAdapterSelectors.selectAll(data); - return props.filter ? allModels.filter(props.filter) : allModels; - }, [data, props.filter]); - const results = useMemo(() => { - if (!props.searchTerm) { - return filteredModels; - } - const results: AnyModelConfig[] = []; - for (const model of filteredModels) { - if (getFilter(model, props.searchTerm)) { - results.push(model); - } - } - return results; - }, [filteredModels, props.searchTerm]); const onSelect = useCallback( (key: string) => { if (!data) { @@ -260,6 +199,18 @@ const ModelList = memo( }, [data, props] ); + const results = useMemo(() => { + if (!props.searchTerm) { + return props.modelConfigs; + } + const results: AnyModelConfig[] = []; + for (const model of props.modelConfigs) { + if (isMatch(model, props.searchTerm)) { + results.push(model); + } + } + return results; + }, [props.modelConfigs, props.searchTerm]); return ( <> @@ -273,6 +224,7 @@ const ModelList = memo( ModelList.displayName = 'ModelList'; const cmdkItemSx: SystemStyleObject = { + display: 'flex', flexDir: 'column', py: 1, px: 2, @@ -282,15 +234,21 @@ const cmdkItemSx: SystemStyleObject = { }, }; +const cmdkItemHeaderSx: SystemStyleObject = { + display: 'flex', + gap: 2, + alignItems: 'center', + justifyContent: 'space-between', +}; + const ChakraCommandItem = chakra(CommandItem); const ModelItem = memo((props: { model: AnyModelConfig; onSelect: (key: string) => void }) => { const { model, onSelect } = props; return ( - + {model.name} - {model.base} diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelTextToImage.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelTextToImage.tsx index 024d5a028e..9f44c4c6a4 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelTextToImage.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelTextToImage.tsx @@ -1,7 +1,6 @@ import { Box, Button, Flex } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; -import type { ModelCmdkOptions } from 'common/components/ModelCmdk/ModelCmdk'; import { useModelCmdk } from 'common/components/ModelCmdk/ModelCmdk'; import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants'; import { selectIsCogView4, selectIsSDXL } from 'features/controlLayers/store/paramsSlice'; @@ -18,23 +17,19 @@ import { noop } from 'lodash-es'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import type { CSSProperties } from 'react'; import { memo } from 'react'; -import { isNonRefinerMainModelConfig } from 'services/api/types'; +import { useAllModels } from 'services/api/hooks/modelsByType'; const overlayScrollbarsStyles: CSSProperties = { height: '100%', width: '100%', }; -const options: ModelCmdkOptions = { - filter: isNonRefinerMainModelConfig, - onSelect: noop, -}; - const ParametersPanelTextToImage = () => { const isSDXL = useAppSelector(selectIsSDXL); const isCogview4 = useAppSelector(selectIsCogView4); const isStylePresetsMenuOpen = useStore($isStylePresetsMenuOpen); - const modelCmdk = useModelCmdk(options); + const [modelConfigs] = useAllModels(); + const modelCmdk = useModelCmdk({ onSelect: noop, modelConfigs }); return ( diff --git a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts index ddd06eff28..0c5f60af59 100644 --- a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts +++ b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts @@ -56,7 +56,18 @@ const buildModelsHook = return [modelConfigs, result] as const; }; +export const useAllModels = () => { + const result = useGetModelConfigsQuery(undefined); + const modelConfigs = useMemo(() => { + if (!result.data) { + return EMPTY_ARRAY; + } + return modelConfigsAdapterSelectors.selectAll(result.data); + }, [result.data]); + + return [modelConfigs, result] as const; +}; export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig); export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig); export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);