feat: Initial port of Model Manager to new tab

This commit is contained in:
blessedcoolant
2023-06-26 23:09:16 +12:00
committed by psychedelicious
parent 9e35643911
commit 2ad5a4ea46
16 changed files with 149 additions and 131 deletions

View File

@@ -1,7 +1,81 @@
import { memo } from 'react';
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
import i18n from 'i18n';
import { ReactNode, memo } from 'react';
import AddModelsPanel from './subpanels/AddModelsPanel';
import MergeModelsPanel from './subpanels/MergeModelsPanel';
import ModelManagerPanel from './subpanels/ModelManagerPanel';
type ModelManagerTabName = 'modelmanager' | 'add_models' | 'merge_models';
type ModelManagerTabInfo = {
id: ModelManagerTabName;
label: string;
content: ReactNode;
};
const modelManagerTabs: ModelManagerTabInfo[] = [
{
id: 'modelmanager',
label: i18n.t('modelManager.modelManager'),
content: <ModelManagerPanel />,
},
{
id: 'add_models',
label: i18n.t('modelManager.addModel'),
content: <AddModelsPanel />,
},
{
id: 'merge_models',
label: i18n.t('modelManager.mergeModels'),
content: <MergeModelsPanel />,
},
];
const ModelManagerTab = () => {
return 'Model Manager';
const renderTabsList = () => {
const modelManagerTabListsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabListsToRender.push(
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
);
});
return (
<TabList
sx={{
w: '100%',
color: 'base.200',
flexDirection: 'row',
borderBottomWidth: 2,
borderColor: 'accent.700',
}}
>
{modelManagerTabListsToRender}
</TabList>
);
};
const renderTabPanels = () => {
const modelManagerTabPanelsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabPanelsToRender.push(
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
);
});
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
};
return (
<Tabs
isLazy
variant="invokeAI"
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }}
>
{renderTabsList()}
{renderTabPanels()}
</Tabs>
);
};
export default memo(ModelManagerTab);

View File

@@ -0,0 +1,10 @@
import { Flex } from '@chakra-ui/react';
import AddModel from 'features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddModel';
export default function AddModelsPanel() {
return (
<Flex>
<AddModel />
</Flex>
);
}

View File

@@ -0,0 +1,337 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
HStack,
Text,
VStack,
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import React from 'react';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import type { FieldInputProps, FormikProps } from 'formik';
import SearchModels from './SearchModels';
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
export default function AddCheckpointModel() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
function hasWhiteSpace(s: string) {
return /\s/.test(s);
}
function baseValidation(value: string) {
let error;
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
return error;
}
const addModelFormValues: InvokeModelConfigProps = {
name: '',
description: '',
config: 'configs/stable-diffusion/v1-inference.yaml',
weights: '',
vae: '',
width: 512,
height: 512,
format: 'ckpt',
default: false,
};
const addModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch(addNewModel(values));
dispatch(setAddNewModelUIOption(null));
};
const [addManually, setAddmanually] = React.useState<boolean>(false);
return (
<VStack gap={2} alignItems="flex-start">
<Flex columnGap={4}>
<IAISimpleCheckbox
isChecked={!addManually}
label={t('modelManager.scanForModels')}
onChange={() => setAddmanually(!addManually)}
/>
<IAISimpleCheckbox
label={t('modelManager.addManually')}
isChecked={addManually}
onChange={() => setAddmanually(!addManually)}
/>
</Flex>
{addManually ? (
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
<VStack rowGap={2}>
<Text fontSize={20} fontWeight="bold" alignSelf="start">
{t('modelManager.manual')}
</Text>
{/* Name */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="full"
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Description */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>
{errors.description}
</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Config */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/>
{!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.configValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Weights */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* VAE */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<HStack width="100%">
{/* Width */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Height */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
</HStack>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
) : (
<SearchModels />
)}
</VStack>
);
}

View File

@@ -0,0 +1,270 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
Text,
VStack,
} from '@chakra-ui/react';
import { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
export default function AddDiffusersModel() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
function hasWhiteSpace(s: string) {
return /\s/.test(s);
}
function baseValidation(value: string) {
let error;
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
return error;
}
const addModelFormValues: InvokeDiffusersModelConfigProps = {
name: '',
description: '',
repo_id: '',
path: '',
format: 'diffusers',
default: false,
vae: {
repo_id: '',
path: '',
},
};
const addModelFormSubmitHandler = (
values: InvokeDiffusersModelConfigProps
) => {
const diffusersModelToAdd = values;
if (values.path === '') delete diffusersModelToAdd.path;
if (values.repo_id === '') delete diffusersModelToAdd.repo_id;
if (values.vae.path === '') delete diffusersModelToAdd.vae.path;
if (values.vae.repo_id === '') delete diffusersModelToAdd.vae.repo_id;
dispatch(addNewModel(diffusersModelToAdd));
dispatch(setAddNewModelUIOption(null));
};
return (
<Flex>
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2}>
<IAIFormItemWrapper>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="2xl"
isRequired
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="2xl"
isRequired
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>{errors.description}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
<Text fontWeight="bold" fontSize="sm">
{t('modelManager.formMessageDiffusersModelLocation')}
</Text>
<Text
sx={{
fontSize: 'sm',
fontStyle: 'italic',
}}
variant="subtext"
>
{t('modelManager.formMessageDiffusersModelLocationDesc')}
</Text>
{/* Path */}
<FormControl isInvalid={!!errors.path && touched.path}>
<FormLabel htmlFor="path" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="path"
name="path"
type="text"
width="2xl"
/>
{!!errors.path && touched.path ? (
<FormErrorMessage>{errors.path}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{/* Repo ID */}
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
<FormLabel htmlFor="repo_id" fontSize="sm">
{t('modelManager.repo_id')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="repo_id"
name="repo_id"
type="text"
width="2xl"
/>
{!!errors.repo_id && touched.repo_id ? (
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.repoIDValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
{/* VAE Path */}
<Text fontWeight="bold">
{t('modelManager.formMessageDiffusersVAELocation')}
</Text>
<Text
sx={{
fontSize: 'sm',
fontStyle: 'italic',
}}
variant="subtext"
>
{t('modelManager.formMessageDiffusersVAELocationDesc')}
</Text>
<FormControl
isInvalid={!!errors.vae?.path && touched.vae?.path}
>
<FormLabel htmlFor="vae.path" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.path"
name="vae.path"
type="text"
width="2xl"
/>
{!!errors.vae?.path && touched.vae?.path ? (
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Repo ID */}
<FormControl
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
>
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
{t('modelManager.vaeRepoID')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.repo_id"
name="vae.repo_id"
type="text"
width="2xl"
/>
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeRepoIDValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIButton type="submit" isLoading={isProcessing}>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
);
}

View File

@@ -0,0 +1,125 @@
import {
Button,
Flex,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
useDisclosure,
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import AddCheckpointModel from './AddCheckpointModel';
import AddDiffusersModel from './AddDiffusersModel';
import IAIIconButton from 'common/components/IAIIconButton';
function AddModelBox({
text,
onClick,
}: {
text: string;
onClick?: () => void;
}) {
return (
<Flex
position="relative"
width="50%"
height={40}
justifyContent="center"
alignItems="center"
onClick={onClick}
as={Button}
>
<Text fontWeight="bold">{text}</Text>
</Flex>
);
}
export default function AddModel() {
const { isOpen, onOpen, onClose } = useDisclosure();
const addNewModelUIOption = useAppSelector(
(state: RootState) => state.ui.addNewModelUIOption
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const addModelModalClose = () => {
onClose();
dispatch(setAddNewModelUIOption(null));
};
return (
<>
<IAIButton
aria-label={t('modelManager.addNewModel')}
tooltip={t('modelManager.addNewModel')}
onClick={onOpen}
size="sm"
>
<Flex columnGap={2} alignItems="center">
<FaPlus />
{t('modelManager.addNew')}
</Flex>
</IAIButton>
<Modal
isOpen={isOpen}
onClose={addModelModalClose}
size="3xl"
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent margin="auto">
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
{addNewModelUIOption !== null && (
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
position="absolute"
variant="ghost"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={2}
icon={<FaArrowLeft />}
/>
)}
<ModalCloseButton />
<ModalBody>
{addNewModelUIOption == null && (
<Flex columnGap={4}>
<AddModelBox
text={t('modelManager.addCheckpointModel')}
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
/>
<AddModelBox
text={t('modelManager.addDiffuserModel')}
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
/>
</Flex>
)}
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
</ModalBody>
<ModalFooter />
</ModalContent>
</Modal>
</>
);
}

View File

@@ -0,0 +1,430 @@
import IAIButton from 'common/components/IAIButton';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAIIconButton from 'common/components/IAIIconButton';
import React from 'react';
import {
Badge,
Flex,
FormControl,
HStack,
Radio,
RadioGroup,
Spacer,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { useTranslation } from 'react-i18next';
import { FaSearch, FaTrash } from 'react-icons/fa';
// import { addNewModel, searchForModels } from 'app/socketio/actions';
import {
setFoundModels,
setSearchFolder,
} from 'features/system/store/systemSlice';
import { setShouldShowExistingModelsInSearch } from 'features/ui/store/uiSlice';
import type { FoundModel } from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import IAIInput from 'common/components/IAIInput';
import { Field, Formik } from 'formik';
import { forEach, remove } from 'lodash-es';
import type { ChangeEvent, ReactNode } from 'react';
import IAIForm from 'common/components/IAIForm';
const existingModelsSelector = createSelector([systemSelector], (system) => {
const { model_list } = system;
const existingModels: string[] = [];
forEach(model_list, (value) => {
existingModels.push(value.weights);
});
return existingModels;
});
interface SearchModelEntry {
model: FoundModel;
modelsToAdd: string[];
setModelsToAdd: React.Dispatch<React.SetStateAction<string[]>>;
}
function SearchModelEntry({
model,
modelsToAdd,
setModelsToAdd,
}: SearchModelEntry) {
const { t } = useTranslation();
const existingModels = useAppSelector(existingModelsSelector);
const foundModelsChangeHandler = (e: ChangeEvent<HTMLInputElement>) => {
if (!modelsToAdd.includes(e.target.value)) {
setModelsToAdd([...modelsToAdd, e.target.value]);
} else {
setModelsToAdd(remove(modelsToAdd, (v) => v !== e.target.value));
}
};
return (
<Flex
flexDirection="column"
gap={2}
backgroundColor={
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
}
paddingX={4}
paddingY={2}
borderRadius={4}
>
<Flex gap={4} alignItems="center" justifyContent="space-between">
<IAISimpleCheckbox
value={model.name}
label={<Text fontWeight={500}>{model.name}</Text>}
isChecked={modelsToAdd.includes(model.name)}
isDisabled={existingModels.includes(model.location)}
onChange={foundModelsChangeHandler}
></IAISimpleCheckbox>
{existingModels.includes(model.location) && (
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
)}
</Flex>
<Text fontStyle="italic" variant="subtext">
{model.location}
</Text>
</Flex>
);
}
export default function SearchModels() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const searchFolder = useAppSelector(
(state: RootState) => state.system.searchFolder
);
const foundModels = useAppSelector(
(state: RootState) => state.system.foundModels
);
const existingModels = useAppSelector(existingModelsSelector);
const shouldShowExistingModelsInSearch = useAppSelector(
(state: RootState) => state.ui.shouldShowExistingModelsInSearch
);
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const [modelsToAdd, setModelsToAdd] = React.useState<string[]>([]);
const [modelType, setModelType] = React.useState<string>('v1');
const [pathToConfig, setPathToConfig] = React.useState<string>('');
const resetSearchModelHandler = () => {
dispatch(setSearchFolder(null));
dispatch(setFoundModels(null));
setModelsToAdd([]);
};
const findModelsHandler = (values: { checkpointFolder: string }) => {
dispatch(searchForModels(values.checkpointFolder));
};
const addAllToSelected = () => {
setModelsToAdd([]);
if (foundModels) {
foundModels.forEach((model) => {
if (!existingModels.includes(model.location)) {
setModelsToAdd((currentModels) => {
return [...currentModels, model.name];
});
}
});
}
};
const removeAllFromSelected = () => {
setModelsToAdd([]);
};
const addSelectedModels = () => {
const modelsToBeAdded = foundModels?.filter((foundModel) =>
modelsToAdd.includes(foundModel.name)
);
const configFiles = {
v1: 'configs/stable-diffusion/v1-inference.yaml',
v2_base: 'configs/stable-diffusion/v2-inference-v.yaml',
v2_768: 'configs/stable-diffusion/v2-inference-v.yaml',
inpainting: 'configs/stable-diffusion/v1-inpainting-inference.yaml',
custom: pathToConfig,
};
modelsToBeAdded?.forEach((model) => {
const modelFormat = {
name: model.name,
description: '',
config: configFiles[modelType as keyof typeof configFiles],
weights: model.location,
vae: '',
width: 512,
height: 512,
default: false,
format: 'ckpt',
};
dispatch(addNewModel(modelFormat));
});
setModelsToAdd([]);
};
const renderFoundModels = () => {
const newFoundModels: ReactNode[] = [];
const existingFoundModels: ReactNode[] = [];
if (foundModels) {
foundModels.forEach((model, index) => {
if (existingModels.includes(model.location)) {
existingFoundModels.push(
<SearchModelEntry
key={index}
model={model}
modelsToAdd={modelsToAdd}
setModelsToAdd={setModelsToAdd}
/>
);
} else {
newFoundModels.push(
<SearchModelEntry
key={index}
model={model}
modelsToAdd={modelsToAdd}
setModelsToAdd={setModelsToAdd}
/>
);
}
});
}
return (
<Flex flexDirection="column" rowGap={4}>
{newFoundModels}
{shouldShowExistingModelsInSearch && existingFoundModels}
</Flex>
);
};
return (
<>
{searchFolder ? (
<Flex
sx={{
padding: 4,
gap: 2,
position: 'relative',
borderRadius: 'base',
alignItems: 'center',
w: 'full',
bg: 'base.900',
}}
>
<Flex
sx={{
flexDir: 'column',
gap: 2,
}}
>
<Text
sx={{
fontWeight: 500,
}}
variant="subtext"
>
{t('modelManager.checkpointFolder')}
</Text>
<Text sx={{ fontWeight: 500 }}>{searchFolder}</Text>
</Flex>
<Spacer />
<IAIIconButton
aria-label={t('modelManager.scanAgain')}
tooltip={t('modelManager.scanAgain')}
icon={<FaSearch />}
fontSize={18}
disabled={isProcessing}
onClick={() => dispatch(searchForModels(searchFolder))}
/>
<IAIIconButton
aria-label={t('modelManager.clearCheckpointFolder')}
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
onClick={resetSearchModelHandler}
/>
</Flex>
) : (
<Formik
initialValues={{ checkpointFolder: '' }}
onSubmit={(values) => {
findModelsHandler(values);
}}
>
{({ handleSubmit }) => (
<IAIForm onSubmit={handleSubmit} width="100%">
<HStack columnGap={2} alignItems="flex-end">
<FormControl flexGrow={1}>
<Field
as={IAIInput}
id="checkpointFolder"
name="checkpointFolder"
type="text"
size="md"
label={t('modelManager.checkpointFolder')}
/>
</FormControl>
<IAIButton
leftIcon={<FaSearch />}
aria-label={t('modelManager.findModels')}
tooltip={t('modelManager.findModels')}
type="submit"
disabled={isProcessing}
px={8}
>
{t('modelManager.findModels')}
</IAIButton>
</HStack>
</IAIForm>
)}
</Formik>
)}
{foundModels && (
<Flex flexDirection="column" rowGap={4} width="full">
<Flex justifyContent="space-between" alignItems="center">
<p>
{t('modelManager.modelsFound')}: {foundModels.length}
</p>
<p>
{t('modelManager.selected')}: {modelsToAdd.length}
</p>
</Flex>
<Flex columnGap={2} justifyContent="space-between">
<Flex columnGap={2}>
<IAIButton
isDisabled={modelsToAdd.length === foundModels.length}
onClick={addAllToSelected}
>
{t('modelManager.selectAll')}
</IAIButton>
<IAIButton
isDisabled={modelsToAdd.length === 0}
onClick={removeAllFromSelected}
>
{t('modelManager.deselectAll')}
</IAIButton>
<IAISimpleCheckbox
label={t('modelManager.showExisting')}
isChecked={shouldShowExistingModelsInSearch}
onChange={() =>
dispatch(
setShouldShowExistingModelsInSearch(
!shouldShowExistingModelsInSearch
)
)
}
/>
</Flex>
<IAIButton
isDisabled={modelsToAdd.length === 0}
onClick={addSelectedModels}
colorScheme="accent"
>
{t('modelManager.addSelected')}
</IAIButton>
</Flex>
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
}}
>
<Flex gap={4}>
<Text fontWeight={500} variant="subtext">
{t('modelManager.pickModelType')}
</Text>
<RadioGroup
value={modelType}
onChange={(v) => setModelType(v)}
defaultValue="v1"
name="model_type"
>
<Flex gap={4}>
<Radio value="v1">
<Text fontSize="sm">{t('modelManager.v1')}</Text>
</Radio>
<Radio value="v2_base">
<Text fontSize="sm">{t('modelManager.v2_base')}</Text>
</Radio>
<Radio value="v2_768">
<Text fontSize="sm">{t('modelManager.v2_768')}</Text>
</Radio>
<Radio value="inpainting">
<Text fontSize="sm">{t('modelManager.inpainting')}</Text>
</Radio>
<Radio value="custom">
<Text fontSize="sm">{t('modelManager.customConfig')}</Text>
</Radio>
</Flex>
</RadioGroup>
</Flex>
{modelType === 'custom' && (
<Flex flexDirection="column" rowGap={2}>
<Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.pathToCustomConfig')}
</Text>
<IAIInput
value={pathToConfig}
onChange={(e) => {
if (e.target.value !== '') setPathToConfig(e.target.value);
}}
width="full"
/>
</Flex>
)}
</Flex>
<Flex
flexDirection="column"
maxHeight={72}
overflowY="scroll"
borderRadius="sm"
gap={2}
>
{foundModels.length > 0 ? (
renderFoundModels()
) : (
<Text
fontWeight="500"
padding={2}
borderRadius="sm"
textAlign="center"
variant="subtext"
>
{t('modelManager.noModelsFound')}
</Text>
)}
</Flex>
</Flex>
)}
</>
);
}

View File

@@ -0,0 +1,10 @@
import { Flex } from '@chakra-ui/react';
import MergeModels from 'features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel/MergeModels';
export default function MergeModelsPanel() {
return (
<Flex>
<MergeModels />
</Flex>
);
}

View File

@@ -0,0 +1,313 @@
import {
Flex,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Radio,
RadioGroup,
Text,
Tooltip,
useDisclosure,
} from '@chakra-ui/react';
// import { mergeDiffusersModels } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect';
import { diffusersModelsSelector } from 'features/system/store/systemSelectors';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import * as InvokeAI from 'app/types/invokeai';
import IAISlider from 'common/components/IAISlider';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
export default function MergeModels() {
const dispatch = useAppDispatch();
const { isOpen, onOpen, onClose } = useDisclosure();
const diffusersModels = useAppSelector(diffusersModelsSelector);
const { t } = useTranslation();
const [modelOne, setModelOne] = useState<string>(
Object.keys(diffusersModels)[0]
);
const [modelTwo, setModelTwo] = useState<string>(
Object.keys(diffusersModels)[1]
);
const [modelThree, setModelThree] = useState<string>('none');
const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
const [modelMergeInterp, setModelMergeInterp] = useState<
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
>('weighted_sum');
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
'root' | 'custom'
>('root');
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] =
useState<string>('');
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const modelOneList = Object.keys(diffusersModels).filter(
(model) => model !== modelTwo && model !== modelThree
);
const modelTwoList = Object.keys(diffusersModels).filter(
(model) => model !== modelOne && model !== modelThree
);
const modelThreeList = [
{ key: t('modelManager.none'), value: 'none' },
...Object.keys(diffusersModels)
.filter((model) => model !== modelOne && model !== modelTwo)
.map((model) => ({ key: model, value: model })),
];
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const mergeModelsHandler = () => {
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
models_to_merge: modelsToMerge,
merged_model_name:
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
alpha: modelMergeAlpha,
interp: modelMergeInterp,
model_merge_save_path:
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce,
};
dispatch(mergeDiffusersModels(mergeModelsInfo));
};
return (
<>
<IAIButton onClick={onOpen} size="sm">
<Flex columnGap={2} alignItems="center">
{t('modelManager.mergeModels')}
</Flex>
</IAIButton>
<Modal
isOpen={isOpen}
onClose={onClose}
size="4xl"
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent fontFamily="Inter" margin="auto" paddingInlineEnd={4}>
<ModalHeader>{t('modelManager.mergeModels')}</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex flexDirection="column" rowGap={4}>
<Flex
sx={{
flexDirection: 'column',
marginBottom: 4,
padding: 4,
borderRadius: 'base',
rowGap: 1,
bg: 'base.900',
}}
>
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
<Text fontSize="sm" variant="subtext">
{t('modelManager.modelMergeHeaderHelp2')}
</Text>
</Flex>
<Flex columnGap={4}>
<IAISelect
label={t('modelManager.modelOne')}
validValues={modelOneList}
onChange={(e) => setModelOne(e.target.value)}
/>
<IAISelect
label={t('modelManager.modelTwo')}
validValues={modelTwoList}
onChange={(e) => setModelTwo(e.target.value)}
/>
<IAISelect
label={t('modelManager.modelThree')}
validValues={modelThreeList}
onChange={(e) => {
if (e.target.value !== 'none') {
setModelThree(e.target.value);
setModelMergeInterp('add_difference');
} else {
setModelThree('none');
setModelMergeInterp('weighted_sum');
}
}}
/>
</Flex>
<IAIInput
label={t('modelManager.mergedModelName')}
value={mergedModelName}
onChange={(e) => setMergedModelName(e.target.value)}
/>
<Flex
sx={{
flexDirection: 'column',
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
}}
>
<IAISlider
label={t('modelManager.alpha')}
min={0.01}
max={0.99}
step={0.01}
value={modelMergeAlpha}
onChange={(v) => setModelMergeAlpha(v)}
withInput
withReset
handleReset={() => setModelMergeAlpha(0.5)}
withSliderMarks
/>
<Text variant="subtext" fontSize="sm">
{t('modelManager.modelMergeAlphaHelp')}
</Text>
</Flex>
<Flex
sx={{
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
}}
>
<Text fontWeight={500} fontSize="sm" variant="subtext">
{t('modelManager.interpolationType')}
</Text>
<RadioGroup
value={modelMergeInterp}
onChange={(
v:
| 'weighted_sum'
| 'sigmoid'
| 'inv_sigmoid'
| 'add_difference'
) => setModelMergeInterp(v)}
>
<Flex columnGap={4}>
{modelThree === 'none' ? (
<>
<Radio value="weighted_sum">
<Text fontSize="sm">
{t('modelManager.weightedSum')}
</Text>
</Radio>
<Radio value="sigmoid">
<Text fontSize="sm">{t('modelManager.sigmoid')}</Text>
</Radio>
<Radio value="inv_sigmoid">
<Text fontSize="sm">
{t('modelManager.inverseSigmoid')}
</Text>
</Radio>
</>
) : (
<Radio value="add_difference">
<Tooltip
label={t(
'modelManager.modelMergeInterpAddDifferenceHelp'
)}
>
<Text fontSize="sm">
{t('modelManager.addDifference')}
</Text>
</Tooltip>
</Radio>
)}
</Flex>
</RadioGroup>
</Flex>
<Flex
sx={{
flexDirection: 'column',
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
}}
>
<Flex columnGap={4}>
<Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.mergedModelSaveLocation')}
</Text>
<RadioGroup
value={modelMergeSaveLocType}
onChange={(v: 'root' | 'custom') =>
setModelMergeSaveLocType(v)
}
>
<Flex columnGap={4}>
<Radio value="root">
<Text fontSize="sm">
{t('modelManager.invokeAIFolder')}
</Text>
</Radio>
<Radio value="custom">
<Text fontSize="sm">{t('modelManager.custom')}</Text>
</Radio>
</Flex>
</RadioGroup>
</Flex>
{modelMergeSaveLocType === 'custom' && (
<IAIInput
label={t('modelManager.mergedModelCustomSaveLocation')}
value={modelMergeCustomSaveLoc}
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/>
)}
</Flex>
<IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')}
isChecked={modelMergeForce}
onChange={(e) => setModelMergeForce(e.target.checked)}
fontWeight="500"
/>
<IAIButton
onClick={mergeModelsHandler}
isLoading={isProcessing}
isDisabled={
modelMergeSaveLocType === 'custom' &&
modelMergeCustomSaveLoc === ''
}
>
{t('modelManager.merge')}
</IAIButton>
</Flex>
</ModalBody>
<ModalFooter />
</ModalContent>
</Modal>
</>
);
}

View File

@@ -0,0 +1,44 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useListModelsQuery } from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() {
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const renderModelEditTabs = () => {
if (!openModel || !pipelineModels) return;
if (pipelineModels['entities'][openModel]['model_format'] === 'diffusers') {
return (
<DiffusersModelEdit
modelToEdit={openModel}
retrievedModel={pipelineModels['entities'][openModel]}
/>
);
} else {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={pipelineModels['entities'][openModel]}
/>
);
}
};
return (
<Flex width="100%" columnGap={8}>
<ModelList />
{renderModelEditTabs()}
</Flex>
);
}

View File

@@ -0,0 +1,322 @@
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput';
import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
Flex,
FormControl,
FormLabel,
HStack,
Text,
VStack,
} from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import IAIForm from 'common/components/IAIForm';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import type { FieldInputProps, FormikProps } from 'formik';
import ModelConvert from './ModelConvert';
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
type CheckpointModelEditProps = {
modelToEdit: string;
retrievedModel: any;
};
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const { modelToEdit, retrievedModel } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [editModelFormValues, setEditModelFormValues] =
useState<InvokeModelConfigProps>({
name: '',
description: '',
config: 'configs/stable-diffusion/v1-inference.yaml',
weights: '',
vae: '',
width: 512,
height: 512,
default: false,
model_format: 'ckpt',
});
useEffect(() => {
if (modelToEdit) {
setEditModelFormValues({
name: modelToEdit,
description: retrievedModel?.description,
config: retrievedModel?.config,
weights: retrievedModel?.weights,
vae: retrievedModel?.vae,
width: retrievedModel?.width,
height: retrievedModel?.height,
default: retrievedModel?.default,
model_format: 'ckpt',
});
}
}, [retrievedModel, modelToEdit]);
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch(
addNewModel({
...values,
width: Number(values.width),
height: Number(values.height),
})
);
};
return modelToEdit ? (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center" gap={4} justifyContent="space-between">
<Text fontSize="lg" fontWeight="bold">
{modelToEdit}
</Text>
<ModelConvert model={modelToEdit} />
</Flex>
<Flex
flexDirection="column"
maxHeight={window.innerHeight - 270}
overflowY="scroll"
paddingInlineEnd={8}
>
<Formik
enableReinitialize={true}
initialValues={editModelFormValues}
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<IAIFormErrorMessage>
{errors.description}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.descriptionValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Config */}
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/>
{!!errors.config && touched.config ? (
<IAIFormErrorMessage>{errors.config}</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.configValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Weights */}
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<IAIFormErrorMessage>
{errors.weights}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.modelLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* VAE */}
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<IAIFormErrorMessage>{errors.vae}</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.vaeLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
<HStack width="100%">
{/* Width */}
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.width && touched.width ? (
<IAIFormErrorMessage>
{errors.width}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.widthValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Height */}
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.height && touched.height ? (
<IAIFormErrorMessage>
{errors.height}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.heightValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
</HStack>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@@ -0,0 +1,257 @@
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
import IAIForm from 'common/components/IAIForm';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
type DiffusersModelEditProps = {
modelToEdit: string;
retrievedModel: any;
};
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const { retrievedModel, modelToEdit } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [editModelFormValues, setEditModelFormValues] =
useState<InvokeDiffusersModelConfigProps>({
name: '',
description: '',
repo_id: '',
path: '',
vae: { repo_id: '', path: '' },
default: false,
model_format: 'diffusers',
});
useEffect(() => {
setEditModelFormValues({
name: modelToEdit,
description: retrievedModel?.description,
path:
retrievedModel?.path && retrievedModel?.path !== 'None'
? retrievedModel?.path
: '',
repo_id:
retrievedModel?.repo_id && retrievedModel?.repo_id !== 'None'
? retrievedModel?.repo_id
: '',
vae: {
repo_id: retrievedModel?.vae?.repo_id
? retrievedModel?.vae?.repo_id
: '',
path: retrievedModel?.vae?.path ? retrievedModel?.vae?.path : '',
},
default: retrievedModel?.default,
model_format: 'diffusers',
});
}, [retrievedModel, modelToEdit]);
const editModelFormSubmitHandler = (
values: InvokeDiffusersModelConfigProps
) => {
const diffusersModelToEdit = values;
if (values.path === '') delete diffusersModelToEdit.path;
if (values.repo_id === '') delete diffusersModelToEdit.repo_id;
if (values.vae.path === '') delete diffusersModelToEdit.vae.path;
if (values.vae.repo_id === '') delete diffusersModelToEdit.vae.repo_id;
dispatch(addNewModel(values));
};
return modelToEdit ? (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center">
<Text fontSize="lg" fontWeight="bold">
{retrievedModel.name}
</Text>
</Flex>
<Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}>
<Formik
enableReinitialize={true}
initialValues={editModelFormValues}
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<IAIFormErrorMessage>
{errors.description}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.descriptionValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Path */}
<FormControl
isInvalid={!!errors.path && touched.path}
isRequired
>
<FormLabel htmlFor="path" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="path"
name="path"
type="text"
width="full"
/>
{!!errors.path && touched.path ? (
<IAIFormErrorMessage>{errors.path}</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.modelLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Repo ID */}
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
<FormLabel htmlFor="repo_id" fontSize="sm">
{t('modelManager.repo_id')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="repo_id"
name="repo_id"
type="text"
width="full"
/>
{!!errors.repo_id && touched.repo_id ? (
<IAIFormErrorMessage>
{errors.repo_id}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.repoIDValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Path */}
<FormControl
isInvalid={!!errors.vae?.path && touched.vae?.path}
>
<FormLabel htmlFor="vae.path" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.path"
name="vae.path"
type="text"
width="full"
/>
{!!errors.vae?.path && touched.vae?.path ? (
<IAIFormErrorMessage>
{errors.vae?.path}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.vaeLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Repo ID */}
<FormControl
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
>
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
{t('modelManager.vaeRepoID')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.repo_id"
name="vae.repo_id"
type="text"
width="full"
/>
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<IAIFormErrorMessage>
{errors.vae?.repo_id}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.vaeRepoIDValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@@ -0,0 +1,144 @@
import {
Flex,
ListItem,
Radio,
RadioGroup,
Text,
UnorderedList,
Tooltip,
} from '@chakra-ui/react';
// import { convertToDiffusers } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { useState, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
interface ModelConvertProps {
model: string;
}
export default function ModelConvert(props: ModelConvertProps) {
const { model } = props;
const model_list = useAppSelector(
(state: RootState) => state.system.model_list
);
const retrievedModel = model_list[model];
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const isConnected = useAppSelector(
(state: RootState) => state.system.isConnected
);
const [saveLocation, setSaveLocation] = useState<string>('same');
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
useEffect(() => {
setSaveLocation('same');
}, [model]);
const modelConvertCancelHandler = () => {
setSaveLocation('same');
};
const modelConvertHandler = () => {
const modelToConvert = {
model_name: model,
save_location: saveLocation,
custom_location:
saveLocation === 'custom' && customSaveLocation !== ''
? customSaveLocation
: null,
};
dispatch(convertToDiffusers(modelToConvert));
};
return (
<IAIAlertDialog
title={`${t('modelManager.convert')} ${model}`}
acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelManager.convert')}`}
triggerComponent={
<IAIButton
size={'sm'}
aria-label={t('modelManager.convertToDiffusers')}
isDisabled={
retrievedModel.status === 'active' || isProcessing || !isConnected
}
className=" modal-close-btn"
marginInlineEnd={8}
>
🧨 {t('modelManager.convertToDiffusers')}
</IAIButton>
}
motionPreset="slideInBottom"
>
<Flex flexDirection="column" rowGap={4}>
<Text>{t('modelManager.convertToDiffusersHelpText1')}</Text>
<UnorderedList>
<ListItem>{t('modelManager.convertToDiffusersHelpText2')}</ListItem>
<ListItem>{t('modelManager.convertToDiffusersHelpText3')}</ListItem>
<ListItem>{t('modelManager.convertToDiffusersHelpText4')}</ListItem>
<ListItem>{t('modelManager.convertToDiffusersHelpText5')}</ListItem>
</UnorderedList>
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
</Flex>
<Flex flexDir="column" gap={4}>
<Flex marginTop={4} flexDir="column" gap={2}>
<Text fontWeight="600">
{t('modelManager.convertToDiffusersSaveLocation')}
</Text>
<RadioGroup value={saveLocation} onChange={(v) => setSaveLocation(v)}>
<Flex gap={4}>
<Radio value="same">
<Tooltip label="Save converted model in the same folder">
{t('modelManager.sameFolder')}
</Tooltip>
</Radio>
<Radio value="root">
<Tooltip label="Save converted model in the InvokeAI root folder">
{t('modelManager.invokeRoot')}
</Tooltip>
</Radio>
<Radio value="custom">
<Tooltip label="Save converted model in a custom folder">
{t('modelManager.custom')}
</Tooltip>
</Radio>
</Flex>
</RadioGroup>
</Flex>
{saveLocation === 'custom' && (
<Flex flexDirection="column" rowGap={2}>
<Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.customSaveLocation')}
</Text>
<IAIInput
value={customSaveLocation}
onChange={(e) => {
if (e.target.value !== '')
setCustomSaveLocation(e.target.value);
}}
width="full"
/>
</Flex>
)}
</Flex>
</IAIAlertDialog>
);
}

View File

@@ -0,0 +1,233 @@
import { Box, Flex, Spinner, Text } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import ModelListItem from './ModelListItem';
import { useTranslation } from 'react-i18next';
import type { ChangeEvent, ReactNode } from 'react';
import React, { useMemo, useState, useTransition } from 'react';
import { useListModelsQuery } from 'services/api/endpoints/models';
function ModelFilterButton({
label,
isActive,
onClick,
}: {
label: string;
isActive: boolean;
onClick: () => void;
}) {
return (
<IAIButton
onClick={onClick}
isActive={isActive}
sx={{
_active: {
bg: 'accent.750',
},
}}
size="sm"
>
{label}
</IAIButton>
);
}
const ModelList = () => {
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
React.useEffect(() => {
const timer = setTimeout(() => {
setRenderModelList(true);
}, 200);
return () => clearTimeout(timer);
}, []);
const [searchText, setSearchText] = useState<string>('');
const [isSelectedFilter, setIsSelectedFilter] = useState<
'all' | 'ckpt' | 'diffusers'
>('all');
const [_, startTransition] = useTransition();
const { t } = useTranslation();
const handleSearchFilter = (e: ChangeEvent<HTMLInputElement>) => {
startTransition(() => {
setSearchText(e.target.value);
});
};
const renderModelListItems = useMemo(() => {
const ckptModelListItemsToRender: ReactNode[] = [];
const diffusersModelListItemsToRender: ReactNode[] = [];
const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = [];
if (!pipelineModels) return;
const modelList = pipelineModels.entities;
Object.keys(modelList).forEach((model, i) => {
if (
modelList[model].name.toLowerCase().includes(searchText.toLowerCase())
) {
filteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model].name}
description={modelList[model].description}
/>
);
if (modelList[model]?.model_format === isSelectedFilter) {
localFilteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model].name}
description={modelList[model].description}
/>
);
}
}
if (modelList[model]?.model_format !== 'diffusers') {
ckptModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model].name}
description={modelList[model].description}
/>
);
} else {
diffusersModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model].name}
description={modelList[model].description}
/>
);
}
});
return searchText !== '' ? (
isSelectedFilter === 'all' ? (
<Box marginTop={4}>{filteredModelListItemsToRender}</Box>
) : (
<Box marginTop={4}>{localFilteredModelListItemsToRender}</Box>
)
) : (
<Flex flexDirection="column" rowGap={6}>
{isSelectedFilter === 'all' && (
<>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
my: 4,
mx: 0,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.checkpointModels')}
</Text>
{ckptModelListItemsToRender}
</Box>
</>
)}
{isSelectedFilter === 'diffusers' && (
<Flex flexDirection="column" marginTop={4}>
{diffusersModelListItemsToRender}
</Flex>
)}
{isSelectedFilter === 'ckpt' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
</Flex>
);
}, [pipelineModels, searchText, t, isSelectedFilter]);
return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
<IAIInput
onChange={handleSearchFilter}
label={t('modelManager.search')}
/>
<Flex
flexDirection="column"
gap={4}
maxHeight={window.innerHeight - 240}
overflow="scroll"
paddingInlineEnd={4}
>
<Flex columnGap={2}>
<ModelFilterButton
label={t('modelManager.allModels')}
onClick={() => setIsSelectedFilter('all')}
isActive={isSelectedFilter === 'all'}
/>
<ModelFilterButton
label={t('modelManager.diffusersModels')}
onClick={() => setIsSelectedFilter('diffusers')}
isActive={isSelectedFilter === 'diffusers'}
/>
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('ckpt')}
isActive={isSelectedFilter === 'ckpt'}
/>
</Flex>
{renderModelList ? (
renderModelListItems
) : (
<Flex
width="100%"
minHeight={96}
justifyContent="center"
alignItems="center"
>
<Spinner />
</Flex>
)}
</Flex>
</Flex>
);
};
export default ModelList;

View File

@@ -0,0 +1,121 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIIconButton from 'common/components/IAIIconButton';
import { setOpenModel } from 'features/system/store/systemSlice';
import { useTranslation } from 'react-i18next';
type ModelListItemProps = {
modelKey: string;
name: string;
description: string | undefined;
};
export default function ModelListItem(props: ModelListItemProps) {
const { isProcessing, isConnected } = useAppSelector(
(state: RootState) => state.system
);
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { modelKey, name, description } = props;
const handleChangeModel = () => {
dispatch(requestModelChange(modelKey));
};
const openModelHandler = () => {
dispatch(setOpenModel(modelKey));
};
const handleModelDelete = () => {
dispatch(deleteModel(modelKey));
dispatch(setOpenModel(null));
};
const statusTextColor = () => {
switch (status) {
case 'active':
return 'ok.500';
case 'cached':
return 'warning.500';
case 'not loaded':
return 'inherit';
}
};
return (
<Flex
alignItems="center"
p={2}
borderRadius="base"
sx={
modelKey === openModel
? {
bg: 'accent.750',
_hover: {
bg: 'accent.750',
},
}
: {
_hover: {
bg: 'base.750',
},
}
}
>
<Box onClick={openModelHandler} cursor="pointer">
<Tooltip label={description} hasArrow placement="bottom">
<Text fontWeight="600">{name}</Text>
</Tooltip>
</Box>
<Spacer onClick={openModelHandler} cursor="pointer" />
<Flex gap={2} alignItems="center">
<Button
size="sm"
onClick={handleChangeModel}
isDisabled={status === 'active' || isProcessing || !isConnected}
>
{t('modelManager.load')}
</Button>
<IAIIconButton
icon={<EditIcon />}
size="sm"
onClick={openModelHandler}
aria-label={t('accessibility.modifyConfig')}
isDisabled={status === 'active' || isProcessing || !isConnected}
/>
<IAIAlertDialog
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
triggerComponent={
<IAIIconButton
icon={<DeleteIcon />}
size="sm"
aria-label={t('modelManager.deleteConfig')}
isDisabled={status === 'active' || isProcessing || !isConnected}
colorScheme="error"
/>
}
>
<Flex rowGap={4} flexDirection="column">
<p style={{ fontWeight: 'bold' }}>{t('modelManager.deleteMsg1')}</p>
<p>{t('modelManager.deleteMsg2')}</p>
</Flex>
</IAIAlertDialog>
</Flex>
</Flex>
);
}