mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-31 00:58:11 -05:00
feat(ui): reworked model selection ui (WIP)
This commit is contained in:
@@ -1,24 +1,11 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Box,
|
||||
chakra,
|
||||
Flex,
|
||||
Input,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalContent,
|
||||
ModalOverlay,
|
||||
Spacer,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { Box, chakra, Flex, Input, Modal, ModalBody, ModalContent, ModalOverlay, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { LRUCache } from 'lru-cache';
|
||||
import { atom } from 'nanostores';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import type { ChangeEvent, RefObject } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
@@ -26,7 +13,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
export type ModelCmdkOptions = {
|
||||
filter?: (modelConfig: AnyModelConfig) => boolean;
|
||||
modelConfigs: AnyModelConfig[];
|
||||
onSelect: (modelConfig: AnyModelConfig) => void;
|
||||
onClose?: () => void;
|
||||
};
|
||||
@@ -52,14 +39,8 @@ const closeModelCmdk = () => {
|
||||
$modelCmdkState.set({ isOpen: false });
|
||||
};
|
||||
|
||||
const regexCache = new LRUCache<string, RegExp>({ max: 1000 });
|
||||
|
||||
const getRegex = (searchTerm: string) => {
|
||||
const cachedRegex = regexCache.get(searchTerm);
|
||||
if (cachedRegex) {
|
||||
return cachedRegex;
|
||||
}
|
||||
const regex = new RegExp(
|
||||
const getRegex = (searchTerm: string) =>
|
||||
new RegExp(
|
||||
searchTerm
|
||||
.trim()
|
||||
.replace(/[-[\]{}()*+!<=:?./\\^$|#,]/g, '')
|
||||
@@ -67,24 +48,8 @@ const getRegex = (searchTerm: string) => {
|
||||
.join('.*'),
|
||||
'gi'
|
||||
);
|
||||
regexCache.set(searchTerm, regex);
|
||||
|
||||
return regex;
|
||||
};
|
||||
|
||||
const filterCache = new LRUCache<string, boolean>({ max: 1000 });
|
||||
const getFilter = (model: AnyModelConfig, searchTerm: string) => {
|
||||
const key = `${model.key}-${searchTerm}`;
|
||||
const cachedFilter = filterCache.get(key);
|
||||
if (cachedFilter !== undefined) {
|
||||
return cachedFilter;
|
||||
}
|
||||
|
||||
if (!searchTerm) {
|
||||
filterCache.set(key, true);
|
||||
return true;
|
||||
}
|
||||
|
||||
const isMatch = (model: AnyModelConfig, searchTerm: string) => {
|
||||
const regex = getRegex(searchTerm);
|
||||
|
||||
if (
|
||||
@@ -99,33 +64,12 @@ const getFilter = (model: AnyModelConfig, searchTerm: string) => {
|
||||
model.format.includes(searchTerm) ||
|
||||
regex.test(model.format)
|
||||
) {
|
||||
filterCache.set(key, true);
|
||||
return true;
|
||||
}
|
||||
|
||||
filterCache.set(key, false);
|
||||
return false;
|
||||
};
|
||||
|
||||
const useEnrichedModelConfigs = () => {
|
||||
const { data } = useGetModelConfigsQuery();
|
||||
const models = useMemo(() => {
|
||||
if (!data || data.ids.length === 0) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
const allModels = modelConfigsAdapterSelectors.selectAll(data);
|
||||
const enrichedModels: (AnyModelConfig & { searchableContent: string })[] = allModels.map((model) => {
|
||||
const searchableContent = [model.name, model.base, model.type, model.format, model.description ?? ''].join(' ');
|
||||
return {
|
||||
...model,
|
||||
searchableContent,
|
||||
};
|
||||
});
|
||||
return enrichedModels;
|
||||
}, [data]);
|
||||
return models;
|
||||
};
|
||||
|
||||
export const useModelCmdk = (options: ModelCmdkOptions) => {
|
||||
const onOpen = useCallback(() => {
|
||||
openModelCmdk(options);
|
||||
@@ -153,21 +97,8 @@ const cmdkRootSx: SystemStyleObject = {
|
||||
};
|
||||
|
||||
export const ModelCmdk = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const state = useStore($modelCmdkState);
|
||||
// Filtering the list is expensive - debounce the search term to avoid stutters
|
||||
const [debouncedSearchTerm] = useDebounce(searchTerm, 300);
|
||||
|
||||
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setSearchTerm(e.target.value);
|
||||
}, []);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
closeModelCmdk();
|
||||
setSearchTerm('');
|
||||
}, []);
|
||||
|
||||
const onSelect = useCallback(
|
||||
(model: AnyModelConfig) => {
|
||||
@@ -176,40 +107,25 @@ export const ModelCmdk = memo(() => {
|
||||
return;
|
||||
}
|
||||
state.onSelect(model);
|
||||
onClose();
|
||||
closeModelCmdk();
|
||||
},
|
||||
[onClose, state]
|
||||
[state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Modal isOpen={state.isOpen} onClose={onClose} useInert={false} initialFocusRef={inputRef} size="xl" isCentered>
|
||||
<Modal
|
||||
isOpen={state.isOpen}
|
||||
onClose={closeModelCmdk}
|
||||
useInert={false}
|
||||
initialFocusRef={inputRef}
|
||||
size="xl"
|
||||
isCentered
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent h="512" maxH="70%">
|
||||
<ModalBody sx={cmdkRootSx}>
|
||||
{state.isOpen && (
|
||||
<CommandRoot loop shouldFilter={false}>
|
||||
<Flex flexDir="column" h="full" gap={2}>
|
||||
<Input ref={inputRef} value={searchTerm} onChange={onChange} placeholder={t('nodes.nodeSearch')} />
|
||||
<Box w="full" h="full">
|
||||
<ScrollableContent>
|
||||
<CommandEmpty>
|
||||
<IAINoContentFallback
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
icon={null}
|
||||
label="No matching items"
|
||||
/>
|
||||
</CommandEmpty>
|
||||
<CommandList>
|
||||
<ModelList searchTerm={debouncedSearchTerm} onSelect={onSelect} filter={state.filter} />
|
||||
</CommandList>
|
||||
</ScrollableContent>
|
||||
</Box>
|
||||
</Flex>
|
||||
</CommandRoot>
|
||||
<ModelCommandRoot inputRef={inputRef} modelConfigs={state.modelConfigs} onSelect={onSelect} />
|
||||
)}
|
||||
</ModalBody>
|
||||
</ModalContent>
|
||||
@@ -219,32 +135,55 @@ export const ModelCmdk = memo(() => {
|
||||
|
||||
ModelCmdk.displayName = 'ModelCmdk';
|
||||
|
||||
const ModelList = memo(
|
||||
const ModelCommandRoot = memo(
|
||||
(props: {
|
||||
searchTerm: string;
|
||||
filter?: (model: AnyModelConfig) => boolean;
|
||||
inputRef: RefObject<HTMLInputElement>;
|
||||
modelConfigs: AnyModelConfig[];
|
||||
onSelect: (model: AnyModelConfig) => void;
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { inputRef, modelConfigs, onSelect } = props;
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
// Filtering the list is expensive - debounce the search term to avoid stutters
|
||||
const [debouncedSearchTerm] = useDebounce(searchTerm, 300);
|
||||
|
||||
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setSearchTerm(e.target.value);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<CommandRoot loop shouldFilter={false}>
|
||||
<Flex flexDir="column" h="full" gap={2}>
|
||||
<Input ref={inputRef} value={searchTerm} onChange={onChange} placeholder={t('nodes.nodeSearch')} />
|
||||
<Box w="full" h="full">
|
||||
<ScrollableContent>
|
||||
<CommandEmpty>
|
||||
<IAINoContentFallback
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
icon={null}
|
||||
label="No matching items"
|
||||
/>
|
||||
</CommandEmpty>
|
||||
<CommandList>
|
||||
<ModelList searchTerm={debouncedSearchTerm} onSelect={onSelect} modelConfigs={modelConfigs} />
|
||||
</CommandList>
|
||||
</ScrollableContent>
|
||||
</Box>
|
||||
</Flex>
|
||||
</CommandRoot>
|
||||
);
|
||||
}
|
||||
);
|
||||
ModelCommandRoot.displayName = 'ModelCommandRoot';
|
||||
|
||||
const ModelList = memo(
|
||||
(props: { searchTerm: string; modelConfigs: AnyModelConfig[]; onSelect: (model: AnyModelConfig) => void }) => {
|
||||
const { data } = useGetModelConfigsQuery();
|
||||
const filteredModels = useMemo(() => {
|
||||
if (!data || data.ids.length === 0) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
const allModels = modelConfigsAdapterSelectors.selectAll(data);
|
||||
return props.filter ? allModels.filter(props.filter) : allModels;
|
||||
}, [data, props.filter]);
|
||||
const results = useMemo(() => {
|
||||
if (!props.searchTerm) {
|
||||
return filteredModels;
|
||||
}
|
||||
const results: AnyModelConfig[] = [];
|
||||
for (const model of filteredModels) {
|
||||
if (getFilter(model, props.searchTerm)) {
|
||||
results.push(model);
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}, [filteredModels, props.searchTerm]);
|
||||
const onSelect = useCallback(
|
||||
(key: string) => {
|
||||
if (!data) {
|
||||
@@ -260,6 +199,18 @@ const ModelList = memo(
|
||||
},
|
||||
[data, props]
|
||||
);
|
||||
const results = useMemo(() => {
|
||||
if (!props.searchTerm) {
|
||||
return props.modelConfigs;
|
||||
}
|
||||
const results: AnyModelConfig[] = [];
|
||||
for (const model of props.modelConfigs) {
|
||||
if (isMatch(model, props.searchTerm)) {
|
||||
results.push(model);
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}, [props.modelConfigs, props.searchTerm]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -273,6 +224,7 @@ const ModelList = memo(
|
||||
ModelList.displayName = 'ModelList';
|
||||
|
||||
const cmdkItemSx: SystemStyleObject = {
|
||||
display: 'flex',
|
||||
flexDir: 'column',
|
||||
py: 1,
|
||||
px: 2,
|
||||
@@ -282,15 +234,21 @@ const cmdkItemSx: SystemStyleObject = {
|
||||
},
|
||||
};
|
||||
|
||||
const cmdkItemHeaderSx: SystemStyleObject = {
|
||||
display: 'flex',
|
||||
gap: 2,
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
};
|
||||
|
||||
const ChakraCommandItem = chakra(CommandItem);
|
||||
|
||||
const ModelItem = memo((props: { model: AnyModelConfig; onSelect: (key: string) => void }) => {
|
||||
const { model, onSelect } = props;
|
||||
return (
|
||||
<ChakraCommandItem value={model.key} onSelect={onSelect} role="button" sx={cmdkItemSx}>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Flex sx={cmdkItemHeaderSx}>
|
||||
<Text fontWeight="semibold">{model.name}</Text>
|
||||
<Spacer />
|
||||
<Text variant="subtext" fontWeight="semibold">
|
||||
{model.base}
|
||||
</Text>
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { Box, Button, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { ModelCmdkOptions } from 'common/components/ModelCmdk/ModelCmdk';
|
||||
import { useModelCmdk } from 'common/components/ModelCmdk/ModelCmdk';
|
||||
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||
import { selectIsCogView4, selectIsSDXL } from 'features/controlLayers/store/paramsSlice';
|
||||
@@ -18,23 +17,19 @@ import { noop } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import { useAllModels } from 'services/api/hooks/modelsByType';
|
||||
|
||||
const overlayScrollbarsStyles: CSSProperties = {
|
||||
height: '100%',
|
||||
width: '100%',
|
||||
};
|
||||
|
||||
const options: ModelCmdkOptions = {
|
||||
filter: isNonRefinerMainModelConfig,
|
||||
onSelect: noop,
|
||||
};
|
||||
|
||||
const ParametersPanelTextToImage = () => {
|
||||
const isSDXL = useAppSelector(selectIsSDXL);
|
||||
const isCogview4 = useAppSelector(selectIsCogView4);
|
||||
const isStylePresetsMenuOpen = useStore($isStylePresetsMenuOpen);
|
||||
const modelCmdk = useModelCmdk(options);
|
||||
const [modelConfigs] = useAllModels();
|
||||
const modelCmdk = useModelCmdk({ onSelect: noop, modelConfigs });
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" flexDir="column" gap={2}>
|
||||
|
||||
@@ -56,7 +56,18 @@ const buildModelsHook =
|
||||
|
||||
return [modelConfigs, result] as const;
|
||||
};
|
||||
export const useAllModels = () => {
|
||||
const result = useGetModelConfigsQuery(undefined);
|
||||
const modelConfigs = useMemo(() => {
|
||||
if (!result.data) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return modelConfigsAdapterSelectors.selectAll(result.data);
|
||||
}, [result.data]);
|
||||
|
||||
return [modelConfigs, result] as const;
|
||||
};
|
||||
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||
|
||||
Reference in New Issue
Block a user