Add onnx models to the model manager UI

This commit is contained in:
Brandon Rising
2023-07-27 09:37:37 -04:00
parent 4d732e06de
commit 024f92f9a9
12 changed files with 465 additions and 109 deletions

View File

@@ -8,7 +8,9 @@ import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
OnnxModelConfigEntity,
useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
@@ -18,7 +20,7 @@ type ModelListProps = {
setSelectedModelId: (name: string | undefined) => void;
};
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive';
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;
@@ -39,6 +41,18 @@ const ModelList = (props: ModelListProps) => {
}),
});
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredOnnxModels: modelsFilter(data, 'onnx', nameFilter),
}),
});
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredOliveModels: modelsFilter(data, 'olive', nameFilter),
}),
});
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
}, []);
@@ -63,10 +77,17 @@ const ModelList = (props: ModelListProps) => {
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('checkpoint')}
isChecked={modelFormatFilter === 'checkpoint'}
onClick={() => setModelFormatFilter('onnx')}
isChecked={modelFormatFilter === 'onnx'}
>
{t('modelManager.checkpointModels')}
{t('modelManager.onnxModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('olive')}
isChecked={modelFormatFilter === 'olive'}
>
{t('modelManager.oliveModels')}
</IAIButton>
</ButtonGroup>
@@ -118,6 +139,42 @@ const ModelList = (props: ModelListProps) => {
</Flex>
</StyledModelContainer>
)}
{['images', 'olive'].includes(modelFormatFilter) &&
filteredOliveModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Olives
</Text>
{filteredOliveModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'onnx'].includes(modelFormatFilter) &&
filteredOnnxModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Onnx
</Text>
{filteredOnnxModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
</Flex>
</Flex>
</Flex>
@@ -127,7 +184,10 @@ const ModelList = (props: ModelListProps) => {
export default ModelList;
const modelsFilter = (
data: EntityState<MainModelConfigEntity> | undefined,
data:
| EntityState<MainModelConfigEntity>
| EntityState<OnnxModelConfigEntity>
| undefined,
model_format: ModelFormat,
nameFilter: string
) => {