refactor(ui): use zod for all redux state

This commit is contained in:
psychedelicious
2025-07-25 13:13:28 +10:00
parent 6962536b4a
commit aed9b1013e
39 changed files with 488 additions and 299 deletions

View File

@@ -1,4 +1,5 @@
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
@@ -6,13 +7,35 @@ import { ModelPicker } from 'features/parameters/components/ModelPicker';
import { selectTileControlNetModel, tileControlnetModelChanged } from 'features/parameters/store/upscaleSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig } from 'services/api/types';
import { type ControlNetModelConfig, isControlNetModelConfig } from 'services/api/types';
const selectTileControlNetModelConfig = createSelector(
selectModelConfigsQuery,
selectTileControlNetModel,
(modelConfigs, modelIdentifierField) => {
if (!modelConfigs.data) {
return null;
}
if (!modelIdentifierField) {
return null;
}
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, modelIdentifierField.key);
if (!modelConfig) {
return null;
}
if (!isControlNetModelConfig(modelConfig)) {
return null;
}
return modelConfig;
}
);
const ParamTileControlNetModel = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const tileControlNetModel = useAppSelector(selectTileControlNetModel);
const tileControlNetModel = useAppSelector(selectTileControlNetModelConfig);
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useControlNetModels();

View File

@@ -1,21 +1,21 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { useMemo } from 'react';
import type { ImageDTO } from 'services/api/types';
const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) =>
const createIsTooLargeToUpscaleSelector = (imageWithDims?: ImageWithDims | null) =>
createSelector(selectUpscaleSlice, selectConfigSlice, (upscale, config) => {
const { upscaleModel, scale } = upscale;
const { maxUpscaleDimension } = config;
if (!maxUpscaleDimension || !upscaleModel || !imageDTO) {
if (!maxUpscaleDimension || !upscaleModel || !imageWithDims) {
// When these are missing, another warning will be shown
return false;
}
const { width, height } = imageDTO;
const { width, height } = imageWithDims;
const maxPixels = maxUpscaleDimension ** 2;
const upscaledPixels = width * scale * height * scale;
@@ -23,7 +23,7 @@ const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) =>
return upscaledPixels > maxPixels;
});
export const useIsTooLargeToUpscale = (imageDTO?: ImageDTO | null) => {
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageDTO), [imageDTO]);
export const useIsTooLargeToUpscale = (imageWithDims?: ImageWithDims | null) => {
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageWithDims), [imageWithDims]);
return useAppSelector(selectIsTooLargeToUpscale);
};

View File

@@ -2,24 +2,32 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterSpandrelImageToImageModel } from 'features/parameters/types/parameterSchemas';
import type { ControlNetModelConfig, ImageDTO } from 'services/api/types';
import type { ControlNetModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import z from 'zod';
export interface UpscaleState {
_version: 1;
upscaleModel: ParameterSpandrelImageToImageModel | null;
upscaleInitialImage: ImageDTO | null;
structure: number;
creativity: number;
tileControlnetModel: ControlNetModelConfig | null;
scale: number;
postProcessingModel: ParameterSpandrelImageToImageModel | null;
tileSize: number;
tileOverlap: number;
}
const zUpscaleState = z.object({
_version: z.literal(2),
upscaleModel: zModelIdentifierField.nullable(),
upscaleInitialImage: zImageWithDims.nullable(),
structure: z.number(),
creativity: z.number(),
tileControlnetModel: zModelIdentifierField.nullable(),
scale: z.number(),
postProcessingModel: zModelIdentifierField.nullable(),
tileSize: z.number(),
tileOverlap: z.number(),
});
export type UpscaleState = z.infer<typeof zUpscaleState>;
const getInitialState = (): UpscaleState => ({
_version: 1,
_version: 2,
upscaleModel: null,
upscaleInitialImage: null,
structure: 0,
@@ -38,7 +46,7 @@ const slice = createSlice({
upscaleModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
state.upscaleModel = action.payload;
},
upscaleInitialImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
upscaleInitialImageChanged: (state, action: PayloadAction<ImageWithDims | null>) => {
state.upscaleInitialImage = action.payload;
},
structureChanged: (state, action: PayloadAction<number>) => {
@@ -77,19 +85,30 @@ export const {
tileOverlapChanged,
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const upscaleSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zUpscaleState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state._version = 2;
// Migrate from v1 to v2: upscaleInitialImage was an ImageDTO, now it's an ImageWithDims
if (state.upscaleInitialImage) {
const { image_name, width, height } = state.upscaleInitialImage;
state.upscaleInitialImage = {
image_name,
width,
height,
};
}
}
return zUpscaleState.parse(state);
},
},
};