refactor(ui): use zod for all redux state

This commit is contained in:
psychedelicious
2025-07-25 13:13:28 +10:00
parent 6962536b4a
commit aed9b1013e
39 changed files with 488 additions and 299 deletions

View File

@@ -112,7 +112,7 @@ const getInitialState = (): CanvasSettingsState => ({
pressureSensitivity: true,
ruleOfThirds: false,
saveAllImagesToGallery: false,
stagingAreaAutoSwitch: 'switch_on_start' as const,
stagingAreaAutoSwitch: 'switch_on_start',
});
const slice = createSlice({
@@ -209,7 +209,7 @@ export const {
export const canvasSettingsSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zCanvasSettingsState,
schema: zCanvasSettingsState,
getInitialState,
persistConfig: {
migrate: (state) => zCanvasSettingsState.parse(state),

View File

@@ -1720,7 +1720,7 @@ const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
export const canvasSliceConfig: SliceConfig<typeof slice> = {
slice,
getInitialState: getInitialCanvasState,
zSchema: zCanvasState,
schema: zCanvasState,
persistConfig: {
migrate: (state) => zCanvasState.parse(state),
},

View File

@@ -3,19 +3,25 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import type { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';
import { queueApi } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
import z from 'zod';
const zCanvasStagingAreaState = z.object({
_version: z.literal(1).default(1),
canvasSessionId: z.string().default(() => getPrefixedId('canvas')),
canvasDiscardedQueueItems: z.array(z.number().int()).default(() => []),
_version: z.literal(1),
canvasSessionId: z.string(),
canvasDiscardedQueueItems: z.array(z.number().int()),
});
type CanvasStagingAreaState = z.infer<typeof zCanvasStagingAreaState>;
const getInitialState = (): CanvasStagingAreaState => zCanvasStagingAreaState.parse({});
const getInitialState = (): CanvasStagingAreaState => ({
_version: 1,
canvasSessionId: getPrefixedId('canvas'),
canvasDiscardedQueueItems: [],
});
const slice = createSlice({
name: 'canvasSession',
@@ -48,18 +54,17 @@ export const { canvasSessionReset, canvasQueueItemDiscarded } = slice.actions;
export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zCanvasStagingAreaState,
schema: zCanvasStagingAreaState,
getInitialState,
persistConfig: {
migrate: (state) => {
{
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
return zCanvasStagingAreaState.parse(state);
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
return zCanvasStagingAreaState.parse(state);
},
},
};

View File

@@ -2,14 +2,16 @@ import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolki
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { LoRA } from 'features/controlLayers/store/types';
import { type LoRA, zLoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import z from 'zod';
type LoRAsState = {
loras: LoRA[];
};
const zLoRAsState = z.object({
loras: z.array(zLoRA),
});
type LoRAsState = z.infer<typeof zLoRAsState>;
const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
@@ -74,16 +76,12 @@ const slice = createSlice({
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const lorasSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zLoRAsState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => zLoRAsState.parse(state),
},
};

View File

@@ -403,7 +403,7 @@ export const {
export const paramsSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zParamsState,
schema: zParamsState,
getInitialState: getInitialParamsState,
persistConfig: {
migrate: (state) => zParamsState.parse(state),

View File

@@ -266,17 +266,12 @@ export const {
refImagesRecalled,
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const refImagesSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zRefImagesState,
schema: zRefImagesState,
getInitialState: getInitialRefImagesState,
persistConfig: {
migrate,
migrate: (state) => zRefImagesState.parse(state),
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
},
};

View File

@@ -3,7 +3,6 @@ import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEnt
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import type { ProgressImage } from 'features/nodes/types/common';
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import {
zParameterCanvasCoherenceMode,
zParameterCFGRescaleMultiplier,
@@ -45,7 +44,7 @@ const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async
}
});
const zImageWithDims = z
export const zImageWithDims = z
.object({
image_name: z.string(),
width: z.number().int().positive(),
@@ -424,12 +423,13 @@ export const zCanvasEntityIdentifer = z.object({
});
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
export type LoRA = {
id: string;
isEnabled: boolean;
model: ParameterLoRAModel;
weight: number;
};
export const zLoRA = z.object({
id: z.string(),
isEnabled: z.boolean(),
model: zServerValidatedModelIdentifierField,
weight: z.number().gte(-1).lte(2),
});
export type LoRA = z.infer<typeof zLoRA>;
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
@@ -574,11 +574,11 @@ export const zParamsState = z.object({
export type ParamsState = z.infer<typeof zParamsState>;
export const getInitialParamsState = (): ParamsState => ({
maskBlur: 16,
maskBlurMethod: 'box' as const,
canvasCoherenceMode: 'Gaussian Blur' as const,
maskBlurMethod: 'box',
canvasCoherenceMode: 'Gaussian Blur',
canvasCoherenceMinDenoise: 0,
canvasCoherenceEdgeSize: 16,
infillMethod: 'lama' as const,
infillMethod: 'lama',
infillTileSize: 32,
infillPatchmatchDownscaleSize: 1,
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
@@ -588,15 +588,15 @@ export const getInitialParamsState = (): ParamsState => ({
img2imgStrength: 0.75,
optimizedDenoisingEnabled: true,
iterations: 1,
scheduler: 'dpmpp_3m_k' as const,
upscaleScheduler: 'kdpm_2' as const,
scheduler: 'dpmpp_3m_k',
upscaleScheduler: 'kdpm_2',
upscaleCfgScale: 2,
seed: 0,
shouldRandomizeSeed: true,
steps: 30,
model: null,
vae: null,
vaePrecision: 'fp32' as const,
vaePrecision: 'fp32',
fluxVAE: null,
seamlessXAxis: false,
seamlessYAxis: false,
@@ -610,7 +610,7 @@ export const getInitialParamsState = (): ParamsState => ({
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
refinerScheduler: 'euler' as const,
refinerScheduler: 'euler',
refinerPositiveAestheticScore: 6,
refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8,
@@ -653,7 +653,7 @@ export const zCanvasState = z.object({
});
export type CanvasState = z.infer<typeof zCanvasState>;
export const getInitialCanvasState = (): CanvasState => ({
_version: 3 as const,
_version: 3,
selectedEntityIdentifier: null,
bookmarkedEntityIdentifier: null,
inpaintMasks: { isHidden: false, entities: [] },
@@ -663,9 +663,9 @@ export const getInitialCanvasState = (): CanvasState => ({
bbox: {
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
scaleMethod: 'auto' as const,
scaleMethod: 'auto',
scaledSize: { width: 512, height: 512 },
modelBase: 'sd-1' as const,
modelBase: 'sd-1',
},
});