feat(ui): reworked model selection ui (WIP)

This commit is contained in:
psychedelicious
2025-04-14 15:09:43 +10:00
parent b191b706c1
commit d2e9237740
3 changed files with 97 additions and 133 deletions

View File

@@ -1,24 +1,11 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Box,
chakra,
Flex,
Input,
Modal,
ModalBody,
ModalContent,
ModalOverlay,
Spacer,
Text,
} from '@invoke-ai/ui-library';
import { Box, chakra, Flex, Input, Modal, ModalBody, ModalContent, ModalOverlay, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { EMPTY_ARRAY } from 'app/store/constants';
import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { LRUCache } from 'lru-cache';
import { atom } from 'nanostores';
import type { ChangeEvent } from 'react';
import type { ChangeEvent, RefObject } from 'react';
import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
@@ -26,7 +13,7 @@ import type { AnyModelConfig } from 'services/api/types';
import { useDebounce } from 'use-debounce';
export type ModelCmdkOptions = {
filter?: (modelConfig: AnyModelConfig) => boolean;
modelConfigs: AnyModelConfig[];
onSelect: (modelConfig: AnyModelConfig) => void;
onClose?: () => void;
};
@@ -52,14 +39,8 @@ const closeModelCmdk = () => {
$modelCmdkState.set({ isOpen: false });
};
const regexCache = new LRUCache<string, RegExp>({ max: 1000 });
const getRegex = (searchTerm: string) => {
const cachedRegex = regexCache.get(searchTerm);
if (cachedRegex) {
return cachedRegex;
}
const regex = new RegExp(
const getRegex = (searchTerm: string) =>
new RegExp(
searchTerm
.trim()
.replace(/[-[\]{}()*+!<=:?./\\^$|#,]/g, '')
@@ -67,24 +48,8 @@ const getRegex = (searchTerm: string) => {
.join('.*'),
'gi'
);
regexCache.set(searchTerm, regex);
return regex;
};
const filterCache = new LRUCache<string, boolean>({ max: 1000 });
const getFilter = (model: AnyModelConfig, searchTerm: string) => {
const key = `${model.key}-${searchTerm}`;
const cachedFilter = filterCache.get(key);
if (cachedFilter !== undefined) {
return cachedFilter;
}
if (!searchTerm) {
filterCache.set(key, true);
return true;
}
const isMatch = (model: AnyModelConfig, searchTerm: string) => {
const regex = getRegex(searchTerm);
if (
@@ -99,33 +64,12 @@ const getFilter = (model: AnyModelConfig, searchTerm: string) => {
model.format.includes(searchTerm) ||
regex.test(model.format)
) {
filterCache.set(key, true);
return true;
}
filterCache.set(key, false);
return false;
};
const useEnrichedModelConfigs = () => {
const { data } = useGetModelConfigsQuery();
const models = useMemo(() => {
if (!data || data.ids.length === 0) {
return EMPTY_ARRAY;
}
const allModels = modelConfigsAdapterSelectors.selectAll(data);
const enrichedModels: (AnyModelConfig & { searchableContent: string })[] = allModels.map((model) => {
const searchableContent = [model.name, model.base, model.type, model.format, model.description ?? ''].join(' ');
return {
...model,
searchableContent,
};
});
return enrichedModels;
}, [data]);
return models;
};
export const useModelCmdk = (options: ModelCmdkOptions) => {
const onOpen = useCallback(() => {
openModelCmdk(options);
@@ -153,21 +97,8 @@ const cmdkRootSx: SystemStyleObject = {
};
export const ModelCmdk = memo(() => {
const { t } = useTranslation();
const inputRef = useRef<HTMLInputElement>(null);
const [searchTerm, setSearchTerm] = useState('');
const state = useStore($modelCmdkState);
// Filtering the list is expensive - debounce the search term to avoid stutters
const [debouncedSearchTerm] = useDebounce(searchTerm, 300);
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setSearchTerm(e.target.value);
}, []);
const onClose = useCallback(() => {
closeModelCmdk();
setSearchTerm('');
}, []);
const onSelect = useCallback(
(model: AnyModelConfig) => {
@@ -176,40 +107,25 @@ export const ModelCmdk = memo(() => {
return;
}
state.onSelect(model);
onClose();
closeModelCmdk();
},
[onClose, state]
[state]
);
return (
<Modal isOpen={state.isOpen} onClose={onClose} useInert={false} initialFocusRef={inputRef} size="xl" isCentered>
<Modal
isOpen={state.isOpen}
onClose={closeModelCmdk}
useInert={false}
initialFocusRef={inputRef}
size="xl"
isCentered
>
<ModalOverlay />
<ModalContent h="512" maxH="70%">
<ModalBody sx={cmdkRootSx}>
{state.isOpen && (
<CommandRoot loop shouldFilter={false}>
<Flex flexDir="column" h="full" gap={2}>
<Input ref={inputRef} value={searchTerm} onChange={onChange} placeholder={t('nodes.nodeSearch')} />
<Box w="full" h="full">
<ScrollableContent>
<CommandEmpty>
<IAINoContentFallback
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
icon={null}
label="No matching items"
/>
</CommandEmpty>
<CommandList>
<ModelList searchTerm={debouncedSearchTerm} onSelect={onSelect} filter={state.filter} />
</CommandList>
</ScrollableContent>
</Box>
</Flex>
</CommandRoot>
<ModelCommandRoot inputRef={inputRef} modelConfigs={state.modelConfigs} onSelect={onSelect} />
)}
</ModalBody>
</ModalContent>
@@ -219,32 +135,55 @@ export const ModelCmdk = memo(() => {
ModelCmdk.displayName = 'ModelCmdk';
const ModelList = memo(
const ModelCommandRoot = memo(
(props: {
searchTerm: string;
filter?: (model: AnyModelConfig) => boolean;
inputRef: RefObject<HTMLInputElement>;
modelConfigs: AnyModelConfig[];
onSelect: (model: AnyModelConfig) => void;
}) => {
const { t } = useTranslation();
const { inputRef, modelConfigs, onSelect } = props;
const [searchTerm, setSearchTerm] = useState('');
// Filtering the list is expensive - debounce the search term to avoid stutters
const [debouncedSearchTerm] = useDebounce(searchTerm, 300);
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setSearchTerm(e.target.value);
}, []);
return (
<CommandRoot loop shouldFilter={false}>
<Flex flexDir="column" h="full" gap={2}>
<Input ref={inputRef} value={searchTerm} onChange={onChange} placeholder={t('nodes.nodeSearch')} />
<Box w="full" h="full">
<ScrollableContent>
<CommandEmpty>
<IAINoContentFallback
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
icon={null}
label="No matching items"
/>
</CommandEmpty>
<CommandList>
<ModelList searchTerm={debouncedSearchTerm} onSelect={onSelect} modelConfigs={modelConfigs} />
</CommandList>
</ScrollableContent>
</Box>
</Flex>
</CommandRoot>
);
}
);
ModelCommandRoot.displayName = 'ModelCommandRoot';
const ModelList = memo(
(props: { searchTerm: string; modelConfigs: AnyModelConfig[]; onSelect: (model: AnyModelConfig) => void }) => {
const { data } = useGetModelConfigsQuery();
const filteredModels = useMemo(() => {
if (!data || data.ids.length === 0) {
return EMPTY_ARRAY;
}
const allModels = modelConfigsAdapterSelectors.selectAll(data);
return props.filter ? allModels.filter(props.filter) : allModels;
}, [data, props.filter]);
const results = useMemo(() => {
if (!props.searchTerm) {
return filteredModels;
}
const results: AnyModelConfig[] = [];
for (const model of filteredModels) {
if (getFilter(model, props.searchTerm)) {
results.push(model);
}
}
return results;
}, [filteredModels, props.searchTerm]);
const onSelect = useCallback(
(key: string) => {
if (!data) {
@@ -260,6 +199,18 @@ const ModelList = memo(
},
[data, props]
);
const results = useMemo(() => {
if (!props.searchTerm) {
return props.modelConfigs;
}
const results: AnyModelConfig[] = [];
for (const model of props.modelConfigs) {
if (isMatch(model, props.searchTerm)) {
results.push(model);
}
}
return results;
}, [props.modelConfigs, props.searchTerm]);
return (
<>
@@ -273,6 +224,7 @@ const ModelList = memo(
ModelList.displayName = 'ModelList';
const cmdkItemSx: SystemStyleObject = {
display: 'flex',
flexDir: 'column',
py: 1,
px: 2,
@@ -282,15 +234,21 @@ const cmdkItemSx: SystemStyleObject = {
},
};
const cmdkItemHeaderSx: SystemStyleObject = {
display: 'flex',
gap: 2,
alignItems: 'center',
justifyContent: 'space-between',
};
const ChakraCommandItem = chakra(CommandItem);
const ModelItem = memo((props: { model: AnyModelConfig; onSelect: (key: string) => void }) => {
const { model, onSelect } = props;
return (
<ChakraCommandItem value={model.key} onSelect={onSelect} role="button" sx={cmdkItemSx}>
<Flex alignItems="center" gap={2}>
<Flex sx={cmdkItemHeaderSx}>
<Text fontWeight="semibold">{model.name}</Text>
<Spacer />
<Text variant="subtext" fontWeight="semibold">
{model.base}
</Text>

View File

@@ -1,7 +1,6 @@
import { Box, Button, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import type { ModelCmdkOptions } from 'common/components/ModelCmdk/ModelCmdk';
import { useModelCmdk } from 'common/components/ModelCmdk/ModelCmdk';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import { selectIsCogView4, selectIsSDXL } from 'features/controlLayers/store/paramsSlice';
@@ -18,23 +17,19 @@ import { noop } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
import { memo } from 'react';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { useAllModels } from 'services/api/hooks/modelsByType';
const overlayScrollbarsStyles: CSSProperties = {
height: '100%',
width: '100%',
};
const options: ModelCmdkOptions = {
filter: isNonRefinerMainModelConfig,
onSelect: noop,
};
const ParametersPanelTextToImage = () => {
const isSDXL = useAppSelector(selectIsSDXL);
const isCogview4 = useAppSelector(selectIsCogView4);
const isStylePresetsMenuOpen = useStore($isStylePresetsMenuOpen);
const modelCmdk = useModelCmdk(options);
const [modelConfigs] = useAllModels();
const modelCmdk = useModelCmdk({ onSelect: noop, modelConfigs });
return (
<Flex w="full" h="full" flexDir="column" gap={2}>

View File

@@ -56,7 +56,18 @@ const buildModelsHook =
return [modelConfigs, result] as const;
};
export const useAllModels = () => {
const result = useGetModelConfigsQuery(undefined);
const modelConfigs = useMemo(() => {
if (!result.data) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(result.data);
}, [result.data]);
return [modelConfigs, result] as const;
};
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);