feat(ui): update UI to use new model config backend

- Update all queries
- Remove Advanced Add
- Removed un-editable, internal-only model attributes from model edit UI (e.g. format, repo variant, model type)
- Update model tags so the list refreshes when a model installs
- Rename some queries, components, variables, types to match backend
- Fix divide-by-zero in install queue
This commit is contained in:
psychedelicious
2024-03-05 19:04:13 +11:00
parent 48119d9010
commit 99407c899f
30 changed files with 993 additions and 1824 deletions

View File

@@ -1,228 +0,0 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { isNil, omitBy } from 'lodash-es';
import { useCallback, useEffect } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const AdvancedImport = () => {
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const { t } = useTranslation();
const {
register,
handleSubmit,
control,
formState: { errors },
setValue,
resetField,
reset,
watch,
} = useForm<AnyModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'diffusers',
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
installModel({
source: values.path,
config: omitBy(values, isNil),
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[installModel, dispatch, t, reset]
);
const watchedModelType = watch('type');
const watchedModelFormat = watch('format');
useEffect(() => {
if (watchedModelType === 'main') {
setValue('format', 'diffusers');
setValue('repo_variant', '');
setValue('variant', 'normal');
}
if (watchedModelType === 'lora') {
setValue('format', 'lycoris');
} else if (watchedModelType === 'embedding') {
setValue('format', 'embedding_file');
} else if (watchedModelType === 'ip_adapter') {
setValue('format', 'invokeai');
} else {
setValue('format', 'diffusers');
}
resetField('upcast_attention');
resetField('ztsnr_training');
resetField('vae');
resetField('config');
resetField('prediction_type');
resetField('image_encoder_model_id');
}, [watchedModelType, resetField, setValue]);
return (
<ScrollableContent>
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
<Flex alignItems="flex-end" gap="4">
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl>
<Text px="2" fontSize="xs" textAlign="center">
{t('modelManager.advancedImportInfo')}
</Text>
</Flex>
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
</FormControl>
</Flex>
)}
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</form>
</ScrollableContent>
);
};

View File

@@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
location: string;
};
export const SimpleImport = () => {
export const InstallModelForm = () => {
const dispatch = useAppDispatch();
const [installModel, { isLoading }] = useInstallModelMutation();

View File

@@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
import { ImportQueueItem } from './ImportQueueItem';
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
export const ImportQueue = () => {
export const ModelInstallQueue = () => {
const dispatch = useAppDispatch();
const { data } = useGetModelImportsQuery();
const { data } = useListModelInstallsQuery();
const [pruneModelImports] = usePruneModelImportsMutation();
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
const pruneQueue = useCallback(() => {
pruneModelImports()
const pruneCompletedModelInstalls = useCallback(() => {
_pruneCompletedModelInstalls()
.unwrap()
.then((_) => {
dispatch(
@@ -41,7 +41,7 @@ export const ImportQueue = () => {
);
}
});
}, [pruneModelImports, dispatch]);
}, [_pruneCompletedModelInstalls, dispatch]);
const pruneAvailable = useMemo(() => {
return data?.some(
@@ -53,14 +53,19 @@ export const ImportQueue = () => {
<Flex flexDir="column" p={3} h="full">
<Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text>
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
<Button
size="sm"
isDisabled={!pruneAvailable}
onClick={pruneCompletedModelInstalls}
tooltip={t('modelManager.pruneTooltip')}
>
{t('modelManager.prune')}
</Button>
</Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<ScrollableContent>
<Flex flexDir="column-reverse" gap="2">
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
</Flex>
</ScrollableContent>
</Box>

View File

@@ -6,17 +6,24 @@ import type { ModelInstallStatus } from 'services/api/types';
const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
error: { colorScheme: 'red', translationKey: 'queue.failed' },
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
};
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
const ModelInstallQueueBadge = ({
status,
errorReason,
}: {
status?: ModelInstallStatus;
errorReason?: string | null;
}) => {
const { t } = useTranslation();
if (!status || !Object.keys(STATUSES).includes(status)) {
return <></>;
return null;
}
return (
@@ -25,4 +32,4 @@ const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus
</Tooltip>
);
};
export default memo(ImportQueueBadge);
export default memo(ModelInstallQueueBadge);

View File

@@ -3,15 +3,16 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
import ImportQueueBadge from './ImportQueueBadge';
import ModelInstallQueueBadge from './ModelInstallQueueBadge';
type ModelListItemProps = {
model: ModelInstallJob;
installJob: ModelInstallJob;
};
const formatBytes = (bytes: number) => {
@@ -26,26 +27,26 @@ const formatBytes = (bytes: number) => {
return `${bytes.toFixed(2)} ${units[i]}`;
};
export const ImportQueueItem = (props: ModelListItemProps) => {
const { model } = props;
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
const { installJob } = props;
const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation();
const [deleteImportModel] = useCancelModelInstallMutation();
const source = useMemo(() => {
if (model.source.type === 'hf') {
return model.source as HFModelSource;
} else if (model.source.type === 'local') {
return model.source as LocalModelSource;
} else if (model.source.type === 'url') {
return model.source as URLModelSource;
if (installJob.source.type === 'hf') {
return installJob.source as HFModelSource;
} else if (installJob.source.type === 'local') {
return installJob.source as LocalModelSource;
} else if (installJob.source.type === 'url') {
return installJob.source as URLModelSource;
} else {
return model.source as LocalModelSource;
return installJob.source as LocalModelSource;
}
}, [model.source]);
}, [installJob.source]);
const handleDeleteModelImport = useCallback(() => {
deleteImportModel(model.id)
deleteImportModel(installJob.id)
.unwrap()
.then((_) => {
dispatch(
@@ -69,7 +70,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
);
}
});
}, [deleteImportModel, model, dispatch]);
}, [deleteImportModel, installJob, dispatch]);
const modelName = useMemo(() => {
switch (source.type) {
@@ -85,19 +86,23 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
}, [source]);
const progressValue = useMemo(() => {
if (model.bytes === undefined || model.total_bytes === undefined) {
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
return null;
}
if (installJob.total_bytes === 0) {
return 0;
}
return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]);
return (installJob.bytes / installJob.total_bytes) * 100;
}, [installJob.bytes, installJob.total_bytes]);
const progressString = useMemo(() => {
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
return '';
}
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
}, [model.bytes, model.total_bytes, model.status]);
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
}, [installJob.bytes, installJob.total_bytes, installJob.status]);
return (
<Flex gap="2" w="full" alignItems="center">
@@ -109,19 +114,21 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
<Flex flexDir="column" flex={1}>
<Tooltip label={progressString}>
<Progress
value={progressValue}
isIndeterminate={progressValue === undefined}
value={progressValue ?? 0}
isIndeterminate={progressValue === null}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
/>
</Tooltip>
</Flex>
<Box minW="100px" textAlign="center">
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
</Box>
<Box minW="20px">
{(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && (
{(installJob.status === 'downloading' ||
installJob.status === 'waiting' ||
installJob.status === 'running') && (
<IconButton
isRound={true}
size="xs"

View File

@@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyScanModelsQuery } from 'services/api/endpoints/models';
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
import { ScanModelsResults } from './ScanModelsResults';
import { ScanModelsResults } from './ScanFolderResults';
export const ScanModelsForm = () => {
const [scanPath, setScanPath] = useState('');
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
const handleSubmitScan = useCallback(async () => {
_scanModels({ scan_path: scanPath }).catch((error) => {
const scanFolder = useCallback(async () => {
_scanFolder({ scan_path: scanPath }).catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
});
}, [_scanModels, scanPath]);
}, [_scanFolder, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setScanPath(e.target.value);
@@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
<Input value={scanPath} onChange={handleSetScanPath} />
</Flex>
<Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}>
<Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
{t('modelManager.scanFolder')}
</Button>
</Flex>

View File

@@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';
import { ScanModelResultItem } from './ScanModelResultItem';
import { ScanModelResultItem } from './ScanFolderResultItem';
type ScanModelResultsProps = {
results: ScanFolderResponse;

View File

@@ -1,12 +1,11 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next';
import { AdvancedImport } from './AddModelPanel/AdvancedImport';
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
import { SimpleImport } from './AddModelPanel/SimpleImport';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
export const ImportModels = () => {
export const InstallModels = () => {
const { t } = useTranslation();
return (
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
@@ -17,15 +16,11 @@ export const ImportModels = () => {
<Tabs variant="collapse" height="100%">
<TabList>
<Tab>{t('common.simple')}</Tab>
<Tab>{t('modelManager.advanced')}</Tab>
<Tab>{t('modelManager.scan')}</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanel>
<SimpleImport />
</TabPanel>
<TabPanel height="100%">
<AdvancedImport />
<InstallModelForm />
</TabPanel>
<TabPanel height="100%">
<ScanModelsForm />
@@ -34,7 +29,7 @@ export const ImportModels = () => {
</Tabs>
</Box>
<Box layerStyle="second" borderRadius="base" w="full" h="50%">
<ImportQueue />
<ModelInstallQueue />
</Box>
</Flex>
);

View File

@@ -5,7 +5,7 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoFilter } from 'react-icons/io5';
export const MODEL_TYPE_LABELS: { [key: string]: string } = {
const MODEL_TYPE_LABELS: { [key: string]: string } = {
main: 'Main',
lora: 'LoRA',
embedding: 'Textual Inversion',

View File

@@ -1,14 +1,14 @@
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ImportModels } from './ImportModels';
import { InstallModels } from './InstallModels';
import { Model } from './ModelPanel/Model';
export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
</Box>
);
};

View File

@@ -5,7 +5,7 @@ import Loading from 'common/components/Loading/Loading';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
@@ -24,7 +24,7 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
export const DefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
useAppSelector(initialStatesSelector);

View File

@@ -8,7 +8,7 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
import { useUpdateModelMutation } from 'services/api/endpoints/models';
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
import { DefaultCfgScale } from './DefaultCfgScale';
@@ -41,7 +41,7 @@ export const DefaultSettingsForm = ({
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
const [updateModel, { isLoading }] = useUpdateModelMutation();
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
@@ -62,7 +62,7 @@ export const DefaultSettingsForm = ({
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
};
editModelMetadata({
updateModel({
key: selectedModelKey,
body: { default_settings: body },
})
@@ -90,7 +90,7 @@ export const DefaultSettingsForm = ({
}
});
},
[selectedModelKey, dispatch, editModelMetadata, t]
[selectedModelKey, dispatch, updateModel, t]
);
return (

View File

@@ -3,9 +3,9 @@ import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
@@ -14,8 +14,12 @@ const options: ComboboxOption[] = [
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];
const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field } = useController(props);
type Props = {
control: Control<UpdateModelArg['body']>;
};
const BaseModelSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'base' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@@ -1,27 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
{ value: 'true', label: 'True' },
{ value: 'false', label: 'False' },
];
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value === 'true');
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(BooleanSelect);

View File

@@ -1,47 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field, formState } = useController(props);
const type = useWatch({ control: props.control, name: 'type' });
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
const options: ComboboxOption[] = useMemo(() => {
const modelType = type || formState.defaultValues?.type;
if (modelType === 'lora') {
return [
{ value: 'lycoris', label: 'LyCORIS' },
{ value: 'diffusers', label: 'Diffusers' },
];
} else if (modelType === 'embedding') {
return [
{ value: 'embedding_file', label: 'Embedding File' },
{ value: 'embedding_folder', label: 'Embedding Folder' },
];
} else if (modelType === 'ip_adapter') {
return [{ value: 'invokeai', label: 'invokeai' }];
} else {
return [
{ value: 'diffusers', label: 'Diffusers' },
{ value: 'checkpoint', label: 'Checkpoint' },
];
}
}, [type, formState.defaultValues?.type]);
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelFormatSelect);

View File

@@ -1,32 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_LABELS } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
] as const;
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelTypeSelect);

View File

@@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'normal', label: 'Normal' },
@@ -12,8 +12,12 @@ const options: ComboboxOption[] = [
{ value: 'depth', label: 'Depth' },
];
const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
type Props = {
control: Control<UpdateModelArg['body']>;
};
const ModelVariantSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'variant' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
@@ -13,8 +13,12 @@ const options: ComboboxOption[] = [
{ value: 'sample', label: 'sample' },
];
const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
type Props = {
control: Control<UpdateModelArg['body']>;
};
const PredictionTypeSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'prediction_type' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@@ -1,27 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
{ value: 'fp16', label: 'fp16' },
{ value: 'fp32', label: 'fp32' },
];
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(RepoVariantSelect);

View File

@@ -2,16 +2,16 @@ import { Flex } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
export const ModelMetadata = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const { data } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
return (
<>
<Flex flexDir="column" height="full" gap="3">
<DataViewer label="metadata" data={metadata || {}} />
<DataViewer label="metadata" data={data?.source_api_response || {}} />
</Flex>
</>
);

View File

@@ -13,7 +13,7 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import { useConvertModelMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
interface ModelConvertProps {
@@ -24,7 +24,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [convertModel, { isLoading }] = useConvertModelMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const modelConvertHandler = useCallback(() => {

View File

@@ -1,5 +1,6 @@
import {
Button,
Checkbox,
Flex,
FormControl,
FormErrorMessage,
@@ -19,66 +20,27 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
import BaseModelSelect from './Fields/BaseModelSelect';
import BooleanSelect from './Fields/BooleanSelect';
import ModelFormatSelect from './Fields/ModelFormatSelect';
import ModelTypeSelect from './Fields/ModelTypeSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import RepoVariantSelect from './Fields/RepoVariantSelect';
export const ModelEdit = () => {
const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
const { t } = useTranslation();
// const modelData = useMemo(() => {
// if (!data) {
// return null;
// }
// const modelFormat = data.format;
// const modelType = data.type;
// if (modelType === 'main') {
// if (modelFormat === 'diffusers') {
// return data as DiffusersModelConfig;
// } else if (modelFormat === 'checkpoint') {
// return data as CheckpointModelConfig;
// }
// }
// switch (modelType) {
// case 'lora':
// return data as LoRAModelConfig;
// case 'embedding':
// return data as TextualInversionModelConfig;
// case 't2i_adapter':
// return data as T2IAdapterModelConfig;
// case 'ip_adapter':
// return data as IPAdapterModelConfig;
// case 'controlnet':
// return data as ControlNetModelConfig;
// case 'vae':
// return data as VAEModelConfig;
// default:
// return null;
// }
// }, [data]);
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
watch,
} = useForm<UpdateModelArg['body']>({
defaultValues: {
...data,
@@ -86,10 +48,7 @@ export const ModelEdit = () => {
mode: 'onChange',
});
const watchedModelType = watch('type');
const watchedModelFormat = watch('format');
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
(values) => {
if (!data?.key) {
return;
@@ -143,33 +102,31 @@ export const ModelEdit = () => {
return (
<Flex flexDir="column" h="full">
<form onSubmit={handleSubmit(onSubmit)}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
})}
size="lg"
/>
<Flex gap={2}>
<Button size="sm" onClick={handleClickCancel}>
{t('common.cancel')}
</Button>
<Button
size="sm"
colorScheme="invokeYellow"
onClick={handleSubmit(onSubmit)}
isLoading={isSubmitting}
isDisabled={Boolean(Object.keys(errors).length)}
>
{t('common.save')}
</Button>
</Flex>
</Flex>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<Button size="sm" onClick={handleClickCancel}>
{t('common.cancel')}
</Button>
<Button
size="sm"
colorScheme="invokeYellow"
onClick={handleSubmit(onSubmit)}
isLoading={isSubmitting}
isDisabled={Boolean(Object.keys(errors).length)}
>
{t('common.save')}
</Button>
</Flex>
<Flex flexDir="column" gap={3} mt="4">
<Flex>
@@ -184,76 +141,22 @@ export const ModelEdit = () => {
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
<BaseModelSelect control={control} />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
{data.type === 'main' && data.format === 'checkpoint' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl>
</Flex>
)}

View File

@@ -91,26 +91,19 @@ export const ModelView = () => {
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
{modelData.format === 'diffusers' && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
)}
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
</Flex>
</>
<Flex gap={2}>
{modelData.format === 'diffusers' && modelData.repo_variant && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</>
)}
</Flex>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>