import IAIButton from 'common/components/IAIButton';
import IAICheckbox from 'common/components/IAICheckbox';
import IAIIconButton from 'common/components/IAIIconButton';
import React from 'react';

import {
  Box,
  Flex,
  FormControl,
  HStack,
  Radio,
  RadioGroup,
  Text,
  VStack,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';

import { useTranslation } from 'react-i18next';

import { FaPlus } from 'react-icons/fa';
import { MdFindInPage } from 'react-icons/md';

import { addNewModel, searchForModels } from 'app/socketio/actions';
import {
  setFoundModels,
  setSearchFolder,
} from 'features/system/store/systemSlice';
import { setShouldShowExistingModelsInSearch } from 'features/ui/store/uiSlice';

import type { FoundModel } from 'app/invokeai';
import type { RootState } from 'app/store';
import IAIInput from 'common/components/IAIInput';
import { Field, Formik } from 'formik';
import { forEach, remove } from 'lodash';
import type { ChangeEvent, ReactNode } from 'react';
import { BiReset } from 'react-icons/bi';

// NOTE(review): the JSX in this file was restored after the checked-in copy
// lost all of its tags. Props that were visible in the damaged copy are kept
// verbatim; purely presentational props (spacing, font sizes) and the
// i18n keys `scanAgain`, `clearCheckpointFolder` and `showExisting` were
// reconstructed — compare against version control.

/**
 * Memoized list of the `weights` values of every entry in `system.model_list`.
 * A found model whose `location` appears in this list is treated as already
 * installed.
 */
const existingModelsSelector = createSelector([systemSelector], (system) => {
  const { model_list } = system;

  const existingModels: string[] = [];

  forEach(model_list, (value) => {
    existingModels.push(value.weights);
  });

  return existingModels;
});

/** Badge overlaid on a search result whose model is already installed. */
function ModelExistsTag() {
  const { t } = useTranslation();
  return (
    <Box
      position="absolute"
      zIndex={2}
      right={4}
      top={4}
      fontSize="0.7rem"
      fontWeight="bold"
      backgroundColor="var(--accent-color)"
      padding="0.2rem 0.5rem"
      borderRadius="0.2rem"
    >
      {t('modelManager.modelExists')}
    </Box>
  );
}

interface SearchModelEntryProps {
  /** A model returned by the folder scan. */
  model: FoundModel;
  /** Names of the models currently selected for import. */
  modelsToAdd: string[];
  /** Setter for `modelsToAdd`, owned by the parent `SearchModels`. */
  setModelsToAdd: React.Dispatch<React.SetStateAction<string[]>>;
}

/**
 * One selectable row of the search results.
 *
 * Toggling the checkbox adds or removes `model.name` from the parent's
 * selection. Models whose `location` is already registered get a badge and a
 * disabled checkbox.
 */
function SearchModelEntry({
  model,
  modelsToAdd,
  setModelsToAdd,
}: SearchModelEntryProps) {
  const existingModels = useAppSelector(existingModelsSelector);
  const isExisting = existingModels.includes(model.location);

  const foundModelsChangeHandler = (e: ChangeEvent<HTMLInputElement>) => {
    // Read the value before scheduling the update; the updater runs later.
    const { value } = e.target;
    setModelsToAdd((current) => {
      if (!current.includes(value)) {
        return [...current, value];
      }
      // lodash `remove` mutates the array it is given, so operate on a copy
      // instead of on React state.
      const remaining = [...current];
      remove(remaining, (name) => name === value);
      return remaining;
    });
  };

  return (
    <Box position="relative">
      {isExisting ? <ModelExistsTag /> : null}
      <IAICheckbox
        value={model.name}
        label={
          <VStack alignItems="start">
            <Text fontWeight="bold">{model.name}</Text>
            <Text fontStyle="italic">{model.location}</Text>
          </VStack>
        }
        isChecked={modelsToAdd.includes(model.name)}
        isDisabled={isExisting}
        onChange={foundModelsChangeHandler}
        padding="1rem"
        backgroundColor="var(--background-color)"
        borderRadius="0.5rem"
        _checked={{
          backgroundColor: 'var(--accent-color)',
          color: 'var(--text-color)',
        }}
        _disabled={{
          backgroundColor: 'var(--background-color-secondary)',
        }}
      />
    </Box>
  );
}

/**
 * Model-manager panel that scans a folder for checkpoints and imports the
 * selected ones.
 *
 * Flow:
 * 1. Submitting a folder dispatches `searchForModels`.
 * 2. The user picks models from `state.system.foundModels`.
 * 3. The user chooses which config to associate with them.
 * 4. "Add Selected" dispatches one `addNewModel` per selected model and clears
 *    the selection.
 */
export default function SearchModels() {
  const dispatch = useAppDispatch();
  const { t } = useTranslation();

  const searchFolder = useAppSelector(
    (state: RootState) => state.system.searchFolder
  );
  const foundModels = useAppSelector(
    (state: RootState) => state.system.foundModels
  );
  const existingModels = useAppSelector(existingModelsSelector);
  const shouldShowExistingModelsInSearch = useAppSelector(
    (state: RootState) => state.ui.shouldShowExistingModelsInSearch
  );
  const isProcessing = useAppSelector(
    (state: RootState) => state.system.isProcessing
  );

  const [modelsToAdd, setModelsToAdd] = React.useState<string[]>([]);
  const [modelType, setModelType] = React.useState<string>('v1');
  const [pathToConfig, setPathToConfig] = React.useState<string>('');

  // Names of found models that are not installed yet (i.e. can be selected).
  const selectableModelNames = (foundModels ?? [])
    .filter((model) => !existingModels.includes(model.location))
    .map((model) => model.name);

  // With a custom config type, a path must be entered before importing.
  const isCustomConfigMissing =
    modelType === 'custom' && pathToConfig.trim() === '';
  const canAddSelected = modelsToAdd.length > 0 && !isCustomConfigMissing;

  const resetSearchModelHandler = () => {
    dispatch(setSearchFolder(null));
    dispatch(setFoundModels(null));
    setModelsToAdd([]);
  };

  const findModelsHandler = (values: { checkpointFolder: string }) => {
    dispatch(searchForModels(values.checkpointFolder));
  };

  // Select every model that is not already installed, in a single update.
  const addAllToSelected = () => {
    setModelsToAdd(selectableModelNames);
  };

  const removeAllFromSelected = () => {
    setModelsToAdd([]);
  };

  const addSelectedModels = () => {
    if (!canAddSelected) return;

    const modelsToBeAdded = foundModels?.filter((foundModel) =>
      modelsToAdd.includes(foundModel.name)
    );

    const configFiles = {
      v1: 'configs/stable-diffusion/v1-inference.yaml',
      v2: 'configs/stable-diffusion/v2-inference-v.yaml',
      inpainting: 'configs/stable-diffusion/v1-inpainting-inference.yaml',
      custom: pathToConfig,
    };
    // `modelType` is only ever set from the radio values, which match these keys.
    const config = configFiles[modelType as keyof typeof configFiles];

    modelsToBeAdded?.forEach((model) => {
      const modelFormat = {
        name: model.name,
        description: '',
        config,
        weights: model.location,
        vae: '',
        width: 512,
        height: 512,
        default: false,
        format: 'ckpt',
      };
      dispatch(addNewModel(modelFormat));
    });
    setModelsToAdd([]);
  };

  // New models are listed first; installed ones follow, and are only shown
  // when the "show existing" toggle is on.
  const renderFoundModels = () => {
    const newFoundModels: ReactNode[] = [];
    const existingFoundModels: ReactNode[] = [];

    foundModels?.forEach((model, index) => {
      const entry = (
        <SearchModelEntry
          key={index}
          model={model}
          modelsToAdd={modelsToAdd}
          setModelsToAdd={setModelsToAdd}
        />
      );
      if (existingModels.includes(model.location)) {
        existingFoundModels.push(entry);
      } else {
        newFoundModels.push(entry);
      }
    });

    return (
      <>
        {newFoundModels}
        {shouldShowExistingModelsInSearch && existingFoundModels}
      </>
    );
  };

  return (
    <>
      {searchFolder ? (
        <Flex
          flexDirection="column"
          padding="1rem"
          backgroundColor="var(--background-color)"
          borderRadius="0.5rem"
          rowGap="0.5rem"
          position="relative"
        >
          <Text fontWeight="bold" fontSize="0.8rem">
            {t('modelManager.checkpointFolder')}
          </Text>
          <Text fontSize="0.8rem">{searchFolder}</Text>
          <IAIIconButton
            aria-label={t('modelManager.scanAgain')}
            tooltip={t('modelManager.scanAgain')}
            icon={<BiReset />}
            position="absolute"
            right={16}
            fontSize={18}
            disabled={isProcessing}
            onClick={() => dispatch(searchForModels(searchFolder))}
          />
          <IAIIconButton
            aria-label={t('modelManager.clearCheckpointFolder')}
            icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />}
            position="absolute"
            right={5}
            onClick={resetSearchModelHandler}
          />
        </Flex>
      ) : (
        <Formik
          initialValues={{ checkpointFolder: '' }}
          onSubmit={(values) => {
            findModelsHandler(values);
          }}
        >
          {({ handleSubmit }) => (
            <form onSubmit={handleSubmit}>
              <HStack columnGap="0.5rem">
                <FormControl isRequired width="max-content">
                  <Field
                    as={IAIInput}
                    id="checkpointFolder"
                    name="checkpointFolder"
                    type="text"
                    width="lg"
                    size="md"
                    label={t('modelManager.checkpointFolder')}
                  />
                </FormControl>
                <IAIIconButton
                  icon={<MdFindInPage />}
                  aria-label={t('modelManager.findModels')}
                  tooltip={t('modelManager.findModels')}
                  type="submit"
                  disabled={isProcessing}
                />
              </HStack>
            </form>
          )}
        </Formik>
      )}
      {foundModels && (
        <Flex flexDirection="column" rowGap="1rem">
          <Flex justifyContent="space-between" alignItems="center">
            <Text>
              {t('modelManager.modelsFound')}: {foundModels.length}
            </Text>
            <Text>
              {t('modelManager.selected')}: {modelsToAdd.length}
            </Text>
          </Flex>

          <Flex columnGap="0.5rem" justifyContent="space-between">
            <Flex columnGap="0.5rem">
              <IAIButton
                isDisabled={modelsToAdd.length === selectableModelNames.length}
                onClick={addAllToSelected}
              >
                {t('modelManager.selectAll')}
              </IAIButton>
              <IAIButton
                isDisabled={modelsToAdd.length === 0}
                onClick={removeAllFromSelected}
              >
                {t('modelManager.deselectAll')}
              </IAIButton>
              <IAICheckbox
                label={t('modelManager.showExisting')}
                isChecked={shouldShowExistingModelsInSearch}
                onChange={() =>
                  dispatch(
                    setShouldShowExistingModelsInSearch(
                      !shouldShowExistingModelsInSearch
                    )
                  )
                }
              />
            </Flex>

            <IAIButton
              isDisabled={!canAddSelected}
              onClick={addSelectedModels}
              backgroundColor={
                canAddSelected ? 'var(--accent-color) !important' : ''
              }
            >
              {t('modelManager.addSelected')}
            </IAIButton>
          </Flex>

          <Flex
            flexDirection="column"
            gap={4}
            padding="1rem"
            backgroundColor="var(--background-color)"
            borderRadius="0.2rem"
          >
            <Flex gap={4}>
              <Text fontWeight="bold">Pick Model Type:</Text>
              <RadioGroup
                value={modelType}
                onChange={(v) => setModelType(v)}
                defaultValue="v1"
                name="model_type"
              >
                <Flex gap={4}>
                  <Radio value="v1">{t('modelManager.v1')}</Radio>
                  <Radio value="v2">{t('modelManager.v2')}</Radio>
                  <Radio value="inpainting">
                    {t('modelManager.inpainting')}
                  </Radio>
                  <Radio value="custom">{t('modelManager.customConfig')}</Radio>
                </Flex>
              </RadioGroup>
            </Flex>

            {modelType === 'custom' && (
              <Flex flexDirection="column" rowGap={2}>
                <Text fontWeight="bold" fontSize="sm">
                  {t('modelManager.pathToCustomConfig')}
                </Text>
                <IAIInput
                  value={pathToConfig}
                  // Store every edit, including clearing the field; the
                  // "Add Selected" button stays disabled while it is empty.
                  onChange={(e) => setPathToConfig(e.target.value)}
                  width="42.5rem"
                />
              </Flex>
            )}
          </Flex>

          <Flex
            flexDirection="column"
            rowGap="1rem"
            maxHeight="18rem"
            overflowY="scroll"
            paddingRight="1rem"
            paddingLeft="0.2rem"
          >
            {foundModels.length > 0 ? (
              modelsToAdd.length === 0 && (
                <Text fontWeight="bold" fontSize={14} textAlign="center">
                  {t('modelManager.selectAndAdd')}
                </Text>
              )
            ) : (
              <Text fontWeight="bold" fontSize={14} textAlign="center">
                {t('modelManager.noModelsFound')}
              </Text>
            )}

            {renderFoundModels()}
          </Flex>
        </Flex>
      )}
    </>
  );
}