From 5c4cbc7fa22bc14429c62317ff38344353c8c54e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 21 May 2025 16:38:07 +1000 Subject: [PATCH] refactor(ui): zod-ify params slice state --- .../controlLayers/store/paramsSlice.ts | 189 ++++++++---------- .../src/features/controlLayers/store/types.ts | 2 +- .../parameters/types/parameterSchemas.ts | 4 +- 3 files changed, 84 insertions(+), 111 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 069e8fc366..72060f38ad 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -2,7 +2,7 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import { deepClone } from 'common/util/deepClone'; -import type { RgbaColor } from 'features/controlLayers/store/types'; +import { type RgbaColor, zRgbaColor } from 'features/controlLayers/store/types'; import { CLIP_SKIP_MAP } from 'features/parameters/types/constants'; import type { ParameterCanvasCoherenceMode, @@ -13,126 +13,99 @@ import type { ParameterCLIPLEmbedModel, ParameterControlLoRAModel, ParameterGuidance, - ParameterMaskBlurMethod, ParameterModel, - ParameterNegativePrompt, - ParameterNegativeStylePromptSDXL, - ParameterPositivePrompt, - ParameterPositiveStylePromptSDXL, ParameterPrecision, ParameterScheduler, ParameterSDXLRefinerModel, - ParameterSeed, - ParameterSteps, - ParameterStrength, ParameterT5EncoderModel, ParameterVAEModel, } from 'features/parameters/types/parameterSchemas'; +import { + zParameterCanvasCoherenceMode, + zParameterCFGRescaleMultiplier, + zParameterCFGScale, + zParameterCLIPEmbedModel, + zParameterCLIPGEmbedModel, + zParameterCLIPLEmbedModel, + zParameterControlLoRAModel, + zParameterGuidance, + zParameterMaskBlurMethod, + zParameterModel, + zParameterNegativePrompt, + zParameterNegativeStylePromptSDXL, + zParameterPositivePrompt, + zParameterPositiveStylePromptSDXL, + zParameterPrecision, + zParameterScheduler, + zParameterSDXLRefinerModel, + zParameterSeed, + zParameterSteps, + zParameterStrength, + zParameterT5EncoderModel, + zParameterVAEModel, +} from 'features/parameters/types/parameterSchemas'; import { clamp } from 'lodash-es'; +import { z } from 'zod'; import { newSessionRequested } from './actions'; -export type ParamsState = { - maskBlur: number; - maskBlurMethod: ParameterMaskBlurMethod; - canvasCoherenceMode: ParameterCanvasCoherenceMode; - canvasCoherenceMinDenoise: ParameterStrength; - canvasCoherenceEdgeSize: number; - infillMethod: string; - infillTileSize: number; - infillPatchmatchDownscaleSize: number; - infillColorValue: RgbaColor; - cfgScale: ParameterCFGScale; - cfgRescaleMultiplier: ParameterCFGRescaleMultiplier; - guidance: ParameterGuidance; - img2imgStrength: ParameterStrength; - optimizedDenoisingEnabled: boolean; - iterations: number; - scheduler: ParameterScheduler; - upscaleScheduler: ParameterScheduler; - upscaleCfgScale: ParameterCFGScale; - seed: ParameterSeed; - shouldRandomizeSeed: boolean; - steps: ParameterSteps; - model: ParameterModel | null; - vae: ParameterVAEModel | null; - vaePrecision: ParameterPrecision; - fluxVAE: ParameterVAEModel | null; - seamlessXAxis: boolean; - seamlessYAxis: boolean; - clipSkip: number; - shouldUseCpuNoise: boolean; - positivePrompt: ParameterPositivePrompt; - negativePrompt: ParameterNegativePrompt; - positivePrompt2: ParameterPositiveStylePromptSDXL; - negativePrompt2: ParameterNegativeStylePromptSDXL; - shouldConcatPrompts: boolean; - refinerModel: ParameterSDXLRefinerModel | null; - refinerSteps: number; - refinerCFGScale: number; - refinerScheduler: ParameterScheduler; - refinerPositiveAestheticScore: number; - refinerNegativeAestheticScore: number; - refinerStart: number; - t5EncoderModel: ParameterT5EncoderModel | null; - clipEmbedModel: ParameterCLIPEmbedModel | null; - clipLEmbedModel: ParameterCLIPLEmbedModel | null; - clipGEmbedModel: ParameterCLIPGEmbedModel | null; - controlLora: ParameterControlLoRAModel | null; -}; +const zParamsState = z.object({ + maskBlur: z.number().default(16), + maskBlurMethod: zParameterMaskBlurMethod.default('box'), + canvasCoherenceMode: zParameterCanvasCoherenceMode.default('Gaussian Blur'), + canvasCoherenceMinDenoise: zParameterStrength.default(0), + canvasCoherenceEdgeSize: z.number().default(16), + infillMethod: z.string().default('lama'), + infillTileSize: z.number().default(32), + infillPatchmatchDownscaleSize: z.number().default(1), + infillColorValue: zRgbaColor.default({ r: 0, g: 0, b: 0, a: 1 }), + cfgScale: zParameterCFGScale.default(7.5), + cfgRescaleMultiplier: zParameterCFGRescaleMultiplier.default(0), + guidance: zParameterGuidance.default(4), + img2imgStrength: zParameterStrength.default(0.75), + optimizedDenoisingEnabled: z.boolean().default(true), + iterations: z.number().default(1), + scheduler: zParameterScheduler.default('dpmpp_3m_k'), + upscaleScheduler: zParameterScheduler.default('kdpm_2'), + upscaleCfgScale: zParameterCFGScale.default(2), + seed: zParameterSeed.default(0), + shouldRandomizeSeed: z.boolean().default(true), + steps: zParameterSteps.default(30), + model: zParameterModel.nullable().default(null), + vae: zParameterVAEModel.nullable().default(null), + vaePrecision: zParameterPrecision.default('fp32'), + fluxVAE: zParameterVAEModel.nullable().default(null), + seamlessXAxis: z.boolean().default(false), + seamlessYAxis: z.boolean().default(false), + clipSkip: z.number().default(0), + shouldUseCpuNoise: z.boolean().default(true), + positivePrompt: zParameterPositivePrompt.default(''), + negativePrompt: zParameterNegativePrompt.default(''), + positivePrompt2: zParameterPositiveStylePromptSDXL.default(''), + negativePrompt2: zParameterNegativeStylePromptSDXL.default(''), + shouldConcatPrompts: z.boolean().default(true), + refinerModel: zParameterSDXLRefinerModel.nullable().default(null), + refinerSteps: z.number().default(20), + refinerCFGScale: z.number().default(7.5), + refinerScheduler: zParameterScheduler.default('euler'), + refinerPositiveAestheticScore: z.number().default(6), + refinerNegativeAestheticScore: z.number().default(2.5), + refinerStart: z.number().default(0.8), + t5EncoderModel: zParameterT5EncoderModel.nullable().default(null), + clipEmbedModel: zParameterCLIPEmbedModel.nullable().default(null), + clipLEmbedModel: zParameterCLIPLEmbedModel.nullable().default(null), + clipGEmbedModel: zParameterCLIPGEmbedModel.nullable().default(null), + controlLora: zParameterControlLoRAModel.nullable().default(null), +}); -const initialState: ParamsState = { - maskBlur: 16, - maskBlurMethod: 'box', - canvasCoherenceMode: 'Gaussian Blur', - canvasCoherenceMinDenoise: 0, - canvasCoherenceEdgeSize: 16, - infillMethod: 'lama', - infillTileSize: 32, - infillPatchmatchDownscaleSize: 1, - infillColorValue: { r: 0, g: 0, b: 0, a: 1 }, - cfgScale: 7.5, - cfgRescaleMultiplier: 0, - guidance: 4, - img2imgStrength: 0.75, - optimizedDenoisingEnabled: true, - iterations: 1, - scheduler: 'dpmpp_3m_k', - upscaleScheduler: 'kdpm_2', - upscaleCfgScale: 2, - seed: 0, - shouldRandomizeSeed: true, - steps: 30, - model: null, - vae: null, - fluxVAE: null, - vaePrecision: 'fp32', - seamlessXAxis: false, - seamlessYAxis: false, - clipSkip: 0, - shouldUseCpuNoise: true, - positivePrompt: '', - negativePrompt: '', - positivePrompt2: '', - negativePrompt2: '', - shouldConcatPrompts: true, - refinerModel: null, - refinerSteps: 20, - refinerCFGScale: 7.5, - refinerScheduler: 'euler', - refinerPositiveAestheticScore: 6, - refinerNegativeAestheticScore: 2.5, - refinerStart: 0.8, - t5EncoderModel: null, - clipEmbedModel: null, - clipLEmbedModel: null, - clipGEmbedModel: null, - controlLora: null, -}; +export type ParamsState = z.infer; + +const INITIAL_STATE = zParamsState.parse({}); +const getInitialState = () => deepClone(INITIAL_STATE); export const paramsSlice = createSlice({ name: 'params', - initialState, + initialState: getInitialState(), reducers: { setIterations: (state, action: PayloadAction) => { state.iterations = action.payload; @@ -300,7 +273,7 @@ export const paramsSlice = createSlice({ const resetState = (state: ParamsState): ParamsState => { // When a new session is requested, we need to keep the current model selections, plus dependent state // like VAE precision. Everything else gets reset to default. - const newState = deepClone(initialState); + const newState = getInitialState(); newState.model = state.model; newState.vae = state.vae; newState.fluxVAE = state.fluxVAE; @@ -366,7 +339,7 @@ const migrate = (state: any): any => { export const paramsPersistConfig: PersistConfig = { name: paramsSlice.name, - initialState, + initialState: getInitialState(), migrate, persistDenylist: [], }; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index a4aee48086..7040938ab2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -72,7 +72,7 @@ const zRgbColor = z.object({ b: z.number().int().min(0).max(255), }); export type RgbColor = z.infer; -const zRgbaColor = zRgbColor.extend({ +export const zRgbaColor = zRgbColor.extend({ a: z.number().min(0).max(1), }); export type RgbaColor = z.infer; diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index acc02785b9..3daf1f4137 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -96,7 +96,7 @@ export type ParameterModel = z.infer; // #endregion // #region SDXL Refiner Model -const zParameterSDXLRefinerModel = zModelIdentifierField; +export const zParameterSDXLRefinerModel = zModelIdentifierField; export type ParameterSDXLRefinerModel = z.infer; // #endregion @@ -188,7 +188,7 @@ export type ParameterSDXLRefinerStart = z.infer; // #endregion