diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index b9f75ac30f..7c9f820729 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -340,6 +340,7 @@ "allModels": "All Models", "checkpointModels": "Checkpoints", "diffusersModels": "Diffusers", + "loraModels": "LoRAs", "safetensorModels": "SafeTensors", "modelAdded": "Model Added", "modelUpdated": "Model Updated", diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts index ca52bfb95a..dd0e738eeb 100644 --- a/invokeai/frontend/web/src/features/parameters/types/constants.ts +++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts @@ -1,3 +1,5 @@ +import { components } from 'services/api/schema'; + export const MODEL_TYPE_MAP = { 'sd-1': 'Stable Diffusion 1.x', 'sd-2': 'Stable Diffusion 2.x', @@ -5,6 +7,13 @@ export const MODEL_TYPE_MAP = { 'sdxl-refiner': 'Stable Diffusion XL Refiner', }; +export const MODEL_TYPE_SHORT_MAP = { + 'sd-1': 'SD1', + 'sd-2': 'SD2', + sdxl: 'SDXL', + 'sdxl-refiner': 'SDXLR', +}; + export const clipSkipMap = { 'sd-1': { maxClip: 12, @@ -23,3 +32,12 @@ export const clipSkipMap = { markers: [0, 1, 2, 3, 5, 10, 15, 20, 24], }, }; + +type LoRAModelFormatMap = { + [key in components['schemas']['LoRAModelFormat']]: string; +}; + +export const LORA_MODEL_FORMAT_MAP: LoRAModelFormatMap = { + lycoris: 'LyCORIS', + diffusers: 'Diffusers', +}; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index 87eb918564..754c8822d1 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -3,20 +3,31 @@ import { Flex, Text } from '@chakra-ui/react'; import { useState } from 'react'; import { MainModelConfigEntity, + DiffusersModelConfigEntity, + LoRAModelConfigEntity, useGetMainModelsQuery, + useGetLoRAModelsQuery, } from 'services/api/endpoints/models'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; +import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit'; import ModelList from './ModelManagerPanel/ModelList'; import { ALL_BASE_MODELS } from 'services/api/constants'; export default function ModelManagerPanel() { const [selectedModelId, setSelectedModelId] = useState(); - const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, { + const { mainModel } = useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({ - model: selectedModelId ? data?.entities[selectedModelId] : undefined, + mainModel: selectedModelId ? data?.entities[selectedModelId] : undefined, }), }); + const { loraModel } = useGetLoRAModelsQuery(undefined, { + selectFromResult: ({ data }) => ({ + loraModel: selectedModelId ? data?.entities[selectedModelId] : undefined, + }), + }); + + const model = mainModel ? mainModel : loraModel; return ( @@ -30,7 +41,7 @@ export default function ModelManagerPanel() { } type ModelEditProps = { - model: MainModelConfigEntity | undefined; + model: MainModelConfigEntity | LoRAModelConfigEntity | undefined; }; const ModelEdit = (props: ModelEditProps) => { @@ -41,7 +52,16 @@ const ModelEdit = (props: ModelEditProps) => { } if (model?.model_format === 'diffusers') { - return ; + return ( + + ); + } + + if (model?.model_type === 'lora') { + return ; } return ( diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx new file mode 100644 index 0000000000..b1c6900f74 --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx @@ -0,0 +1,82 @@ +import { Divider, Flex, Text } from '@chakra-ui/react'; +import { useForm } from '@mantine/form'; +import IAIMantineTextInput from 'common/components/IAIMantineInput'; +import { + LORA_MODEL_FORMAT_MAP, + MODEL_TYPE_MAP, +} from 'features/parameters/types/constants'; +import { useTranslation } from 'react-i18next'; +import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import { LoRAModelConfig } from 'services/api/types'; +import BaseModelSelect from '../shared/BaseModelSelect'; + +type LoRAModelEditProps = { + model: LoRAModelConfigEntity; +}; + +export default function LoRAModelEdit(props: LoRAModelEditProps) { + const { model } = props; + + const { t } = useTranslation(); + + const loraEditForm = useForm({ + initialValues: { + model_name: model.model_name ? model.model_name : '', + base_model: model.base_model, + model_type: 'lora', + path: model.path ? model.path : '', + description: model.description ? model.description : '', + model_format: model.model_format, + }, + validate: { + path: (value) => + value.trim().length === 0 ? 'Must provide a path' : null, + }, + }); + + return ( + + + + {model.model_name} + + + {MODEL_TYPE_MAP[model.base_model]} Model ⋅{' '} + {LORA_MODEL_FORMAT_MAP[model.model_format]} format + + + + +
+ + + + + + + {t('Editing LoRA model metadata is not yet supported.')} + + +
+
+ ); +} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index f3d0eae495..e4d8a7b15a 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -9,6 +9,8 @@ import { useTranslation } from 'react-i18next'; import { MainModelConfigEntity, useGetMainModelsQuery, + useGetLoRAModelsQuery, + LoRAModelConfigEntity, } from 'services/api/endpoints/models'; import ModelListItem from './ModelListItem'; import { ALL_BASE_MODELS } from 'services/api/constants'; @@ -20,22 +22,42 @@ type ModelListProps = { type ModelFormat = 'images' | 'checkpoint' | 'diffusers'; +type ModelType = 'main' | 'lora'; + +type CombinedModelFormat = ModelFormat | 'lora'; + const ModelList = (props: ModelListProps) => { const { selectedModelId, setSelectedModelId } = props; const { t } = useTranslation(); const [nameFilter, setNameFilter] = useState(''); const [modelFormatFilter, setModelFormatFilter] = - useState('images'); + useState('images'); const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({ - filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter), + filteredDiffusersModels: modelsFilter( + data, + 'main', + 'diffusers', + nameFilter + ), }), }); const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({ - filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter), + filteredCheckpointModels: modelsFilter( + data, + 'main', + 'checkpoint', + nameFilter + ), + }), + }); + + const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, { + selectFromResult: ({ data }) => ({ + filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter), }), }); @@ -68,6 +90,13 @@ const ModelList = (props: ModelListProps) => { > {t('modelManager.checkpointModels')} + setModelFormatFilter('lora')} + isChecked={modelFormatFilter === 'lora'} + > + {t('modelManager.loraModels')} + {
)} + {['images', 'lora'].includes(modelFormatFilter) && + filteredLoraModels.length > 0 && ( + + + + LoRAs + + {filteredLoraModels.map((model) => ( + + ))} + + + )} @@ -126,12 +173,13 @@ const ModelList = (props: ModelListProps) => { export default ModelList; -const modelsFilter = ( - data: EntityState | undefined, - model_format: ModelFormat, +const modelsFilter = ( + data: EntityState | undefined, + model_type: ModelType, + model_format: ModelFormat | undefined, nameFilter: string ) => { - const filteredModels: MainModelConfigEntity[] = []; + const filteredModels: T[] = []; forEach(data?.entities, (model) => { if (!model) { return; @@ -141,9 +189,11 @@ const modelsFilter = ( .toLowerCase() .includes(nameFilter.toLowerCase()); - const matchesFormat = model.model_format === model_format; + const matchesFormat = + model_format === undefined || model.model_format === model_format; + const matchesType = model.model_type === model_type; - if (matchesFilter && matchesFormat) { + if (matchesFilter && matchesFormat && matchesType) { filteredModels.push(model); } }); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index 7f4fb0c736..9380eed688 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -9,29 +9,26 @@ import { selectIsBusy } from 'features/system/store/systemSelectors'; import { addToast } from 'features/system/store/systemSlice'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants'; import { MainModelConfigEntity, + LoRAModelConfigEntity, useDeleteMainModelsMutation, + useDeleteLoRAModelsMutation, } from 'services/api/endpoints/models'; type ModelListItemProps = { - model: MainModelConfigEntity; + model: MainModelConfigEntity | LoRAModelConfigEntity; isSelected: boolean; setSelectedModelId: (v: string | undefined) => void; }; -const modelBaseTypeMap = { - 'sd-1': 'SD1', - 'sd-2': 'SD2', - sdxl: 'SDXL', - 'sdxl-refiner': 'SDXLR', -}; - export default function ModelListItem(props: ModelListItemProps) { const isBusy = useAppSelector(selectIsBusy); const { t } = useTranslation(); const dispatch = useAppDispatch(); const [deleteMainModel] = useDeleteMainModelsMutation(); + const [deleteLoRAModel] = useDeleteLoRAModelsMutation(); const { model, isSelected, setSelectedModelId } = props; @@ -40,7 +37,10 @@ export default function ModelListItem(props: ModelListItemProps) { }, [model.id, setSelectedModelId]); const handleModelDelete = useCallback(() => { - deleteMainModel(model) + const method = { main: deleteMainModel, lora: deleteLoRAModel }[ + model.model_type + ]; + method(model) .unwrap() .then((_) => { dispatch( @@ -60,14 +60,21 @@ export default function ModelListItem(props: ModelListItemProps) { title: `${t('modelManager.modelDeleteFailed')}: ${ model.model_name }`, - status: 'success', + status: 'error', }) ) ); } }); setSelectedModelId(undefined); - }, [deleteMainModel, model, setSelectedModelId, dispatch, t]); + }, [ + deleteMainModel, + deleteLoRAModel, + model, + setSelectedModelId, + dispatch, + t, + ]); return ( @@ -100,8 +107,8 @@ export default function ModelListItem(props: ModelListItemProps) { { - modelBaseTypeMap[ - model.base_model as keyof typeof modelBaseTypeMap + MODEL_TYPE_SHORT_MAP[ + model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP ] } diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 3d0013a62c..aa93be62b5 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -62,6 +62,10 @@ type DeleteMainModelArg = { type DeleteMainModelResponse = void; +type DeleteLoRAModelArg = DeleteMainModelArg; + +type DeleteLoRAModelResponse = void; + type ConvertMainModelArg = { base_model: BaseModelType; model_name: string; @@ -320,6 +324,18 @@ export const modelsApi = api.injectEndpoints({ ); }, }), + deleteLoRAModels: build.mutation< + DeleteLoRAModelResponse, + DeleteLoRAModelArg + >({ + query: ({ base_model, model_name }) => { + return { + url: `models/${base_model}/lora/${model_name}`, + method: 'DELETE', + }; + }, + invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], + }), getControlNetModels: build.query< EntityState, void @@ -467,6 +483,7 @@ export const { useAddMainModelsMutation, useConvertMainModelsMutation, useMergeMainModelsMutation, + useDeleteLoRAModelsMutation, useSyncModelsMutation, useGetModelsInFolderQuery, useGetCheckpointConfigsQuery,