refactor(ui): zod-ify params slice state

This commit is contained in:
psychedelicious
2025-05-21 16:38:07 +10:00
parent 5f2f12f803
commit 5c4cbc7fa2
3 changed files with 84 additions and 111 deletions

View File

@@ -2,7 +2,7 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import type { RgbaColor } from 'features/controlLayers/store/types';
import { type RgbaColor, zRgbaColor } from 'features/controlLayers/store/types';
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
import type {
ParameterCanvasCoherenceMode,
@@ -13,126 +13,99 @@ import type {
ParameterCLIPLEmbedModel,
ParameterControlLoRAModel,
ParameterGuidance,
ParameterMaskBlurMethod,
ParameterModel,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterPrecision,
ParameterScheduler,
ParameterSDXLRefinerModel,
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterT5EncoderModel,
ParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import {
zParameterCanvasCoherenceMode,
zParameterCFGRescaleMultiplier,
zParameterCFGScale,
zParameterCLIPEmbedModel,
zParameterCLIPGEmbedModel,
zParameterCLIPLEmbedModel,
zParameterControlLoRAModel,
zParameterGuidance,
zParameterMaskBlurMethod,
zParameterModel,
zParameterNegativePrompt,
zParameterNegativeStylePromptSDXL,
zParameterPositivePrompt,
zParameterPositiveStylePromptSDXL,
zParameterPrecision,
zParameterScheduler,
zParameterSDXLRefinerModel,
zParameterSeed,
zParameterSteps,
zParameterStrength,
zParameterT5EncoderModel,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { clamp } from 'lodash-es';
import { z } from 'zod';
import { newSessionRequested } from './actions';
export type ParamsState = {
maskBlur: number;
maskBlurMethod: ParameterMaskBlurMethod;
canvasCoherenceMode: ParameterCanvasCoherenceMode;
canvasCoherenceMinDenoise: ParameterStrength;
canvasCoherenceEdgeSize: number;
infillMethod: string;
infillTileSize: number;
infillPatchmatchDownscaleSize: number;
infillColorValue: RgbaColor;
cfgScale: ParameterCFGScale;
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
guidance: ParameterGuidance;
img2imgStrength: ParameterStrength;
optimizedDenoisingEnabled: boolean;
iterations: number;
scheduler: ParameterScheduler;
upscaleScheduler: ParameterScheduler;
upscaleCfgScale: ParameterCFGScale;
seed: ParameterSeed;
shouldRandomizeSeed: boolean;
steps: ParameterSteps;
model: ParameterModel | null;
vae: ParameterVAEModel | null;
vaePrecision: ParameterPrecision;
fluxVAE: ParameterVAEModel | null;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
clipSkip: number;
shouldUseCpuNoise: boolean;
positivePrompt: ParameterPositivePrompt;
negativePrompt: ParameterNegativePrompt;
positivePrompt2: ParameterPositiveStylePromptSDXL;
negativePrompt2: ParameterNegativeStylePromptSDXL;
shouldConcatPrompts: boolean;
refinerModel: ParameterSDXLRefinerModel | null;
refinerSteps: number;
refinerCFGScale: number;
refinerScheduler: ParameterScheduler;
refinerPositiveAestheticScore: number;
refinerNegativeAestheticScore: number;
refinerStart: number;
t5EncoderModel: ParameterT5EncoderModel | null;
clipEmbedModel: ParameterCLIPEmbedModel | null;
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
controlLora: ParameterControlLoRAModel | null;
};
const zParamsState = z.object({
maskBlur: z.number().default(16),
maskBlurMethod: zParameterMaskBlurMethod.default('box'),
canvasCoherenceMode: zParameterCanvasCoherenceMode.default('Gaussian Blur'),
canvasCoherenceMinDenoise: zParameterStrength.default(0),
canvasCoherenceEdgeSize: z.number().default(16),
infillMethod: z.string().default('lama'),
infillTileSize: z.number().default(32),
infillPatchmatchDownscaleSize: z.number().default(1),
infillColorValue: zRgbaColor.default({ r: 0, g: 0, b: 0, a: 1 }),
cfgScale: zParameterCFGScale.default(7.5),
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier.default(0),
guidance: zParameterGuidance.default(4),
img2imgStrength: zParameterStrength.default(0.75),
optimizedDenoisingEnabled: z.boolean().default(true),
iterations: z.number().default(1),
scheduler: zParameterScheduler.default('dpmpp_3m_k'),
upscaleScheduler: zParameterScheduler.default('kdpm_2'),
upscaleCfgScale: zParameterCFGScale.default(2),
seed: zParameterSeed.default(0),
shouldRandomizeSeed: z.boolean().default(true),
steps: zParameterSteps.default(30),
model: zParameterModel.nullable().default(null),
vae: zParameterVAEModel.nullable().default(null),
vaePrecision: zParameterPrecision.default('fp32'),
fluxVAE: zParameterVAEModel.nullable().default(null),
seamlessXAxis: z.boolean().default(false),
seamlessYAxis: z.boolean().default(false),
clipSkip: z.number().default(0),
shouldUseCpuNoise: z.boolean().default(true),
positivePrompt: zParameterPositivePrompt.default(''),
negativePrompt: zParameterNegativePrompt.default(''),
positivePrompt2: zParameterPositiveStylePromptSDXL.default(''),
negativePrompt2: zParameterNegativeStylePromptSDXL.default(''),
shouldConcatPrompts: z.boolean().default(true),
refinerModel: zParameterSDXLRefinerModel.nullable().default(null),
refinerSteps: z.number().default(20),
refinerCFGScale: z.number().default(7.5),
refinerScheduler: zParameterScheduler.default('euler'),
refinerPositiveAestheticScore: z.number().default(6),
refinerNegativeAestheticScore: z.number().default(2.5),
refinerStart: z.number().default(0.8),
t5EncoderModel: zParameterT5EncoderModel.nullable().default(null),
clipEmbedModel: zParameterCLIPEmbedModel.nullable().default(null),
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable().default(null),
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable().default(null),
controlLora: zParameterControlLoRAModel.nullable().default(null),
});
const initialState: ParamsState = {
maskBlur: 16,
maskBlurMethod: 'box',
canvasCoherenceMode: 'Gaussian Blur',
canvasCoherenceMinDenoise: 0,
canvasCoherenceEdgeSize: 16,
infillMethod: 'lama',
infillTileSize: 32,
infillPatchmatchDownscaleSize: 1,
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
cfgScale: 7.5,
cfgRescaleMultiplier: 0,
guidance: 4,
img2imgStrength: 0.75,
optimizedDenoisingEnabled: true,
iterations: 1,
scheduler: 'dpmpp_3m_k',
upscaleScheduler: 'kdpm_2',
upscaleCfgScale: 2,
seed: 0,
shouldRandomizeSeed: true,
steps: 30,
model: null,
vae: null,
fluxVAE: null,
vaePrecision: 'fp32',
seamlessXAxis: false,
seamlessYAxis: false,
clipSkip: 0,
shouldUseCpuNoise: true,
positivePrompt: '',
negativePrompt: '',
positivePrompt2: '',
negativePrompt2: '',
shouldConcatPrompts: true,
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
refinerScheduler: 'euler',
refinerPositiveAestheticScore: 6,
refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8,
t5EncoderModel: null,
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
controlLora: null,
};
export type ParamsState = z.infer<typeof zParamsState>;
const INITIAL_STATE = zParamsState.parse({});
const getInitialState = () => deepClone(INITIAL_STATE);
export const paramsSlice = createSlice({
name: 'params',
initialState,
initialState: getInitialState(),
reducers: {
setIterations: (state, action: PayloadAction<number>) => {
state.iterations = action.payload;
@@ -300,7 +273,7 @@ export const paramsSlice = createSlice({
const resetState = (state: ParamsState): ParamsState => {
// When a new session is requested, we need to keep the current model selections, plus dependent state
// like VAE precision. Everything else gets reset to default.
const newState = deepClone(initialState);
const newState = getInitialState();
newState.model = state.model;
newState.vae = state.vae;
newState.fluxVAE = state.fluxVAE;
@@ -366,7 +339,7 @@ const migrate = (state: any): any => {
export const paramsPersistConfig: PersistConfig<ParamsState> = {
name: paramsSlice.name,
initialState,
initialState: getInitialState(),
migrate,
persistDenylist: [],
};

View File

@@ -72,7 +72,7 @@ const zRgbColor = z.object({
b: z.number().int().min(0).max(255),
});
export type RgbColor = z.infer<typeof zRgbColor>;
const zRgbaColor = zRgbColor.extend({
export const zRgbaColor = zRgbColor.extend({
a: z.number().min(0).max(1),
});
export type RgbaColor = z.infer<typeof zRgbaColor>;

View File

@@ -96,7 +96,7 @@ export type ParameterModel = z.infer<typeof zParameterModel>;
// #endregion
// #region SDXL Refiner Model
const zParameterSDXLRefinerModel = zModelIdentifierField;
export const zParameterSDXLRefinerModel = zModelIdentifierField;
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
// #endregion
@@ -188,7 +188,7 @@ export type ParameterSDXLRefinerStart = z.infer<typeof zParameterSDXLRefinerStar
// #endregion
// #region Mask Blur Method
const zParameterMaskBlurMethod = z.enum(['box', 'gaussian']);
export const [zParameterMaskBlurMethod, isParameterMaskBlurMethod] = buildParameter(z.enum(['box', 'gaussian']));
export type ParameterMaskBlurMethod = z.infer<typeof zParameterMaskBlurMethod>;
// #endregion