mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add onnx models to the model manager UI
This commit is contained in:
@@ -8,7 +8,9 @@ import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
MainModelConfigEntity,
|
||||
OnnxModelConfigEntity,
|
||||
useGetMainModelsQuery,
|
||||
useGetOnnxModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import ModelListItem from './ModelListItem';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
@@ -18,7 +20,7 @@ type ModelListProps = {
|
||||
setSelectedModelId: (name: string | undefined) => void;
|
||||
};
|
||||
|
||||
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive';
|
||||
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
||||
|
||||
const ModelList = (props: ModelListProps) => {
|
||||
const { selectedModelId, setSelectedModelId } = props;
|
||||
@@ -39,6 +41,18 @@ const ModelList = (props: ModelListProps) => {
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredOnnxModels: modelsFilter(data, 'onnx', nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredOliveModels: modelsFilter(data, 'olive', nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setNameFilter(e.target.value);
|
||||
}, []);
|
||||
@@ -63,10 +77,17 @@ const ModelList = (props: ModelListProps) => {
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={() => setModelFormatFilter('checkpoint')}
|
||||
isChecked={modelFormatFilter === 'checkpoint'}
|
||||
onClick={() => setModelFormatFilter('onnx')}
|
||||
isChecked={modelFormatFilter === 'onnx'}
|
||||
>
|
||||
{t('modelManager.checkpointModels')}
|
||||
{t('modelManager.onnxModels')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={() => setModelFormatFilter('olive')}
|
||||
isChecked={modelFormatFilter === 'olive'}
|
||||
>
|
||||
{t('modelManager.oliveModels')}
|
||||
</IAIButton>
|
||||
</ButtonGroup>
|
||||
|
||||
@@ -118,6 +139,42 @@ const ModelList = (props: ModelListProps) => {
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
{['images', 'olive'].includes(modelFormatFilter) &&
|
||||
filteredOliveModels.length > 0 && (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
Olives
|
||||
</Text>
|
||||
{filteredOliveModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
{['images', 'onnx'].includes(modelFormatFilter) &&
|
||||
filteredOnnxModels.length > 0 && (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
Onnx
|
||||
</Text>
|
||||
{filteredOnnxModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
@@ -127,7 +184,10 @@ const ModelList = (props: ModelListProps) => {
|
||||
export default ModelList;
|
||||
|
||||
const modelsFilter = (
|
||||
data: EntityState<MainModelConfigEntity> | undefined,
|
||||
data:
|
||||
| EntityState<MainModelConfigEntity>
|
||||
| EntityState<OnnxModelConfigEntity>
|
||||
| undefined,
|
||||
model_format: ModelFormat,
|
||||
nameFilter: string
|
||||
) => {
|
||||
|
||||
Reference in New Issue
Block a user