From 28373dbb98e201c4964329ffda5f89974f73695e Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 18 Jun 2023 17:36:23 +1200 Subject: [PATCH] cleanup: Updated model slice names to be more descriptive Basically updated all slices to be more descriptive in their names. Did so in order to make sure theres good naming scheme available for secondary models. --- .../enhancers/reduxRemember/unserialize.ts | 8 +-- .../listeners/socketio/socketConnected.ts | 7 ++- invokeai/frontend/web/src/app/store/store.ts | 8 +-- .../fields/ModelInputFieldComponent.tsx | 6 +- .../system/components/ModelSelect.tsx | 9 ++- .../features/system/store/modelSelectors.ts | 37 ++++++------ .../system/store/models/sd1ModelSlice.ts | 53 ----------------- .../store/models/sd1PipelineModelSlice.ts | 57 +++++++++++++++++++ .../system/store/models/sd2ModelSlice.ts | 53 ----------------- .../store/models/sd2PipelineModelSlice.ts | 57 +++++++++++++++++++ .../system/store/modelsPersistDenylist.ts | 8 +-- .../frontend/web/src/services/thunks/model.ts | 7 ++- 12 files changed, 164 insertions(+), 146 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/system/store/models/sd1ModelSlice.ts create mode 100644 invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts delete mode 100644 invokeai/frontend/web/src/features/system/store/models/sd2ModelSlice.ts create mode 100644 invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts index 93cc19f832..dc1c25c015 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts @@ -7,8 +7,8 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialConfigState } from 'features/system/store/configSlice'; -import { sd1InitialModelsState } from 'features/system/store/models/sd1ModelSlice'; -import { sd2InitialModelsState } from 'features/system/store/models/sd2ModelSlice'; +import { sd1InitialPipelineModelsState } from 'features/system/store/models/sd1PipelineModelSlice'; +import { sd2InitialPipelineModelsState } from 'features/system/store/models/sd2PipelineModelSlice'; import { initialSystemState } from 'features/system/store/systemSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialUIState } from 'features/ui/store/uiSlice'; @@ -22,8 +22,8 @@ const initialStates: { gallery: initialGalleryState, generation: initialGenerationState, lightbox: initialLightboxState, - sd1models: sd1InitialModelsState, - sd2models: sd2InitialModelsState, + sd1pipelinemodels: sd1InitialPipelineModelsState, + sd2pipelinemodels: sd2InitialPipelineModelsState, nodes: initialNodesState, postprocessing: initialPostprocessingState, system: initialSystemState, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 2ce1ba45e6..14263b643b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -15,7 +15,8 @@ export const addSocketConnectedEventListener = () => { moduleLog.debug({ timestamp }, 'Connected'); - const { sd1models, sd2models, nodes, config, images } = getState(); + const { sd1pipelinemodels, sd2pipelinemodels, nodes, config, images } = + getState(); const { disabledTabs } = config; @@ -23,11 +24,11 @@ export const addSocketConnectedEventListener = () => { dispatch(receivedPageOfImages()); } - if (!sd1models.ids.length) { + if (!sd1pipelinemodels.ids.length) { dispatch(getModels({ baseModel: 'sd-1', modelType: 'pipeline' })); } - if (!sd2models.ids.length) { + if (!sd2pipelinemodels.ids.length) { dispatch(getModels({ baseModel: 'sd-2', modelType: 'pipeline' })); } diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 7ff3fb8dc5..06aa6d3535 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -29,8 +29,8 @@ import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { stateSanitizer } from './middleware/devtools/stateSanitizer'; // Model Reducers -import sd1ModelReducer from 'features/system/store/models/sd1ModelSlice'; -import sd2ModelReducer from 'features/system/store/models/sd2ModelSlice'; +import sd1PipelineModelReducer from 'features/system/store/models/sd1PipelineModelSlice'; +import sd2PipelineModelReducer from 'features/system/store/models/sd2PipelineModelSlice'; import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; @@ -41,8 +41,8 @@ const allReducers = { gallery: galleryReducer, generation: generationReducer, lightbox: lightboxReducer, - sd1models: sd1ModelReducer, - sd2models: sd2ModelReducer, + sd1pipelinemodels: sd1PipelineModelReducer, + sd2pipelinemodels: sd2PipelineModelReducer, nodes: nodesReducer, postprocessing: postprocessingReducer, system: systemReducer, diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index 3842e8da3a..480c8591bb 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -17,7 +17,7 @@ const ModelInputFieldComponent = ( const dispatch = useAppDispatch(); - const { sd1ModelDropDownData, sd2ModelDropdownData } = + const { sd1PipelineModelDropDownData, sd2PipelineModelDropdownData } = useAppSelector(modelSelector); const handleValueChanged = (e: ChangeEvent) => { @@ -33,8 +33,8 @@ const ModelInputFieldComponent = ( return ( ); }; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 43de144991..813bd9fb70 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -20,8 +20,11 @@ const MODEL_LOADER_MAP = { const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { selectedModel, sd1ModelDropDownData, sd2ModelDropdownData } = - useAppSelector(modelSelector); + const { + selectedModel, + sd1PipelineModelDropDownData, + sd2PipelineModelDropdownData, + } = useAppSelector(modelSelector); useEffect(() => { if (selectedModel) @@ -48,7 +51,7 @@ const ModelSelect = () => { label={t('modelManager.model')} value={selectedModel?.name ?? ''} placeholder="Pick one" - data={sd1ModelDropDownData.concat(sd2ModelDropdownData)} + data={sd1PipelineModelDropDownData.concat(sd2PipelineModelDropdownData)} onChange={handleChangeModel} /> ); diff --git a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts index 6e101da5f5..b63c6d256c 100644 --- a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts +++ b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts @@ -3,26 +3,30 @@ import { RootState } from 'app/store/store'; import { IAISelectDataType } from 'common/components/IAIMantineSelect'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { isEqual } from 'lodash-es'; + import { - selectAllSD1Models, - selectByIdSD1Models, -} from './models/sd1ModelSlice'; + selectAllSD1PipelineModels, + selectByIdSD1PipelineModels, +} from './models/sd1PipelineModelSlice'; + import { - selectAllSD2Models, - selectByIdSD2Models, -} from './models/sd2ModelSlice'; + selectAllSD2PipelineModels, + selectByIdSD2PipelineModels, +} from './models/sd2PipelineModelSlice'; export const modelSelector = createSelector( [(state: RootState) => state, generationSelector], (state, generation) => { - let selectedModel = selectByIdSD1Models(state, generation.model); + let selectedModel = selectByIdSD1PipelineModels(state, generation.model); if (selectedModel === undefined) - selectedModel = selectByIdSD2Models(state, generation.model); + selectedModel = selectByIdSD2PipelineModels(state, generation.model); - const sd1Models = selectAllSD1Models(state); - const sd2Models = selectAllSD2Models(state); + const sd1PipelineModels = selectAllSD1PipelineModels(state); + const sd2PipelineModels = selectAllSD2PipelineModels(state); - const sd1ModelDropDownData = selectAllSD1Models(state) + const allPipelineModels = sd1PipelineModels.concat(sd2PipelineModels); + + const sd1PipelineModelDropDownData = selectAllSD1PipelineModels(state) .map((m) => ({ value: m.name, label: m.name, @@ -30,7 +34,7 @@ export const modelSelector = createSelector( })) .sort((a, b) => a.label.localeCompare(b.label)); - const sd2ModelDropdownData = selectAllSD2Models(state) + const sd2PipelineModelDropdownData = selectAllSD2PipelineModels(state) .map((m) => ({ value: m.name, label: m.name, @@ -40,10 +44,11 @@ export const modelSelector = createSelector( return { selectedModel, - sd1Models, - sd2Models, - sd1ModelDropDownData, - sd2ModelDropdownData, + allPipelineModels, + sd1PipelineModels, + sd2PipelineModels, + sd1PipelineModelDropDownData, + sd2PipelineModelDropdownData, }; }, { diff --git a/invokeai/frontend/web/src/features/system/store/models/sd1ModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd1ModelSlice.ts deleted file mode 100644 index 9f62fde264..0000000000 --- a/invokeai/frontend/web/src/features/system/store/models/sd1ModelSlice.ts +++ /dev/null @@ -1,53 +0,0 @@ -import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { - StableDiffusion1ModelCheckpointConfig, - StableDiffusion1ModelDiffusersConfig, -} from 'services/api'; - -import { getModels } from 'services/thunks/model'; - -export type SD1ModelType = ( - | StableDiffusion1ModelCheckpointConfig - | StableDiffusion1ModelDiffusersConfig -) & { - name: string; -}; - -export const sd1ModelsAdapter = createEntityAdapter({ - selectId: (model) => model.name, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); - -export const sd1InitialModelsState = sd1ModelsAdapter.getInitialState(); - -export type SD1ModelState = typeof sd1InitialModelsState; - -export const sd1ModelsSlice = createSlice({ - name: 'sd1models', - initialState: sd1InitialModelsState, - reducers: { - modelAdded: sd1ModelsAdapter.upsertOne, - }, - extraReducers(builder) { - /** - * Received Models - FULFILLED - */ - builder.addCase(getModels.fulfilled, (state, action) => { - if (action.meta.arg.baseModel !== 'sd-1') return; - sd1ModelsAdapter.setAll(state, action.payload); - }); - }, -}); - -export const { - selectAll: selectAllSD1Models, - selectById: selectByIdSD1Models, - selectEntities: selectEntitiesSD1Models, - selectIds: selectIdsSD1Models, - selectTotal: selectTotalSD1Models, -} = sd1ModelsAdapter.getSelectors((state) => state.sd1models); - -export const { modelAdded } = sd1ModelsSlice.actions; - -export default sd1ModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts new file mode 100644 index 0000000000..5755b14886 --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts @@ -0,0 +1,57 @@ +import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { + StableDiffusion1ModelCheckpointConfig, + StableDiffusion1ModelDiffusersConfig, +} from 'services/api'; + +import { getModels } from 'services/thunks/model'; + +export type SD1PipelineModelType = ( + | StableDiffusion1ModelCheckpointConfig + | StableDiffusion1ModelDiffusersConfig +) & { + name: string; +}; + +export const sd1PipelineModelsAdapter = + createEntityAdapter({ + selectId: (model) => model.name, + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); + +export const sd1InitialPipelineModelsState = + sd1PipelineModelsAdapter.getInitialState(); + +export type SD1PipelineModelState = typeof sd1InitialPipelineModelsState; + +export const sd1PipelineModelsSlice = createSlice({ + name: 'sd1models', + initialState: sd1InitialPipelineModelsState, + reducers: { + modelAdded: sd1PipelineModelsAdapter.upsertOne, + }, + extraReducers(builder) { + /** + * Received Models - FULFILLED + */ + builder.addCase(getModels.fulfilled, (state, action) => { + if (action.meta.arg.baseModel !== 'sd-1') return; + sd1PipelineModelsAdapter.setAll(state, action.payload); + }); + }, +}); + +export const { + selectAll: selectAllSD1PipelineModels, + selectById: selectByIdSD1PipelineModels, + selectEntities: selectEntitiesSD1PipelineModels, + selectIds: selectIdsSD1PipelineModels, + selectTotal: selectTotalSD1PipelineModels, +} = sd1PipelineModelsAdapter.getSelectors( + (state) => state.sd1pipelinemodels +); + +export const { modelAdded } = sd1PipelineModelsSlice.actions; + +export default sd1PipelineModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/models/sd2ModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd2ModelSlice.ts deleted file mode 100644 index e8e1f5bedf..0000000000 --- a/invokeai/frontend/web/src/features/system/store/models/sd2ModelSlice.ts +++ /dev/null @@ -1,53 +0,0 @@ -import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { - StableDiffusion2ModelCheckpointConfig, - StableDiffusion2ModelDiffusersConfig, -} from 'services/api'; - -import { getModels } from 'services/thunks/model'; - -export type SD2ModelType = ( - | StableDiffusion2ModelCheckpointConfig - | StableDiffusion2ModelDiffusersConfig -) & { - name: string; -}; - -export const sd2ModelsAdapater = createEntityAdapter({ - selectId: (model) => model.name, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); - -export const sd2InitialModelsState = sd2ModelsAdapater.getInitialState(); - -export type SD2ModelState = typeof sd2InitialModelsState; - -export const sd2ModelsSlice = createSlice({ - name: 'sd2models', - initialState: sd2InitialModelsState, - reducers: { - modelAdded: sd2ModelsAdapater.upsertOne, - }, - extraReducers(builder) { - /** - * Received Models - FULFILLED - */ - builder.addCase(getModels.fulfilled, (state, action) => { - if (action.meta.arg.baseModel !== 'sd-2') return; - sd2ModelsAdapater.setAll(state, action.payload); - }); - }, -}); - -export const { - selectAll: selectAllSD2Models, - selectById: selectByIdSD2Models, - selectEntities: selectEntitiesSD2Models, - selectIds: selectIdsSD2Models, - selectTotal: selectTotalSD2Models, -} = sd2ModelsAdapater.getSelectors((state) => state.sd2models); - -export const { modelAdded } = sd2ModelsSlice.actions; - -export default sd2ModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts new file mode 100644 index 0000000000..0c307e23cc --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts @@ -0,0 +1,57 @@ +import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { + StableDiffusion2ModelCheckpointConfig, + StableDiffusion2ModelDiffusersConfig, +} from 'services/api'; + +import { getModels } from 'services/thunks/model'; + +export type SD2PipelineModelType = ( + | StableDiffusion2ModelCheckpointConfig + | StableDiffusion2ModelDiffusersConfig +) & { + name: string; +}; + +export const sd2PipelineModelsAdapater = + createEntityAdapter({ + selectId: (model) => model.name, + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); + +export const sd2InitialPipelineModelsState = + sd2PipelineModelsAdapater.getInitialState(); + +export type SD2PipelineModelState = typeof sd2InitialPipelineModelsState; + +export const sd2PipelineModelsSlice = createSlice({ + name: 'sd2models', + initialState: sd2InitialPipelineModelsState, + reducers: { + modelAdded: sd2PipelineModelsAdapater.upsertOne, + }, + extraReducers(builder) { + /** + * Received Models - FULFILLED + */ + builder.addCase(getModels.fulfilled, (state, action) => { + if (action.meta.arg.baseModel !== 'sd-2') return; + sd2PipelineModelsAdapater.setAll(state, action.payload); + }); + }, +}); + +export const { + selectAll: selectAllSD2PipelineModels, + selectById: selectByIdSD2PipelineModels, + selectEntities: selectEntitiesSD2PipelineModels, + selectIds: selectIdsSD2PipelineModels, + selectTotal: selectTotalSD2PipelineModels, +} = sd2PipelineModelsAdapater.getSelectors( + (state) => state.sd2pipelinemodels +); + +export const { modelAdded } = sd2PipelineModelsSlice.actions; + +export default sd2PipelineModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts index 7b0d78d37e..417a399cf2 100644 --- a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts @@ -1,9 +1,9 @@ -import { SD1ModelState } from './models/sd1ModelSlice'; -import { SD2ModelState } from './models/sd2ModelSlice'; +import { SD1PipelineModelState } from './models/sd1PipelineModelSlice'; +import { SD2PipelineModelState } from './models/sd2PipelineModelSlice'; /** * Models slice persist denylist */ export const modelsPersistDenylist: - | (keyof SD1ModelState)[] - | (keyof SD2ModelState)[] = ['entities', 'ids']; + | (keyof SD1PipelineModelState)[] + | (keyof SD2PipelineModelState)[] = ['entities', 'ids']; diff --git a/invokeai/frontend/web/src/services/thunks/model.ts b/invokeai/frontend/web/src/services/thunks/model.ts index 4d134439f7..039748fa3f 100644 --- a/invokeai/frontend/web/src/services/thunks/model.ts +++ b/invokeai/frontend/web/src/services/thunks/model.ts @@ -1,6 +1,7 @@ import { log } from 'app/logging/useLogger'; import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { SD1ModelType } from 'features/system/store/models/sd1ModelSlice'; +import { SD1PipelineModelType } from 'features/system/store/models/sd1PipelineModelSlice'; +import { SD2PipelineModelType } from 'features/system/store/models/sd2PipelineModelSlice'; import { reduce, size } from 'lodash-es'; import { BaseModelType, ModelType, ModelsService } from 'services/api'; @@ -30,7 +31,7 @@ export const getModels = createAppAsyncThunk( modelsAccumulator[modelName] = { ...model, name: modelName }; return modelsAccumulator; }, - {} as Record + {} as Record ); } @@ -41,7 +42,7 @@ export const getModels = createAppAsyncThunk( modelsAccumulator[modelName] = { ...model, name: modelName }; return modelsAccumulator; }, - {} as Record + {} as Record ); }