From a0b0c30be98238c3511fb097ae1dcab7f8993cec Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 21 May 2025 16:40:34 +1000 Subject: [PATCH] refactor(ui): move params state to big file of canvas zod stuff --- .../controlLayers/store/canvasSlice.ts | 10 +-- .../controlLayers/store/paramsSlice.ts | 89 ++----------------- .../src/features/controlLayers/store/types.ts | 75 +++++++++++++++- 3 files changed, 84 insertions(+), 90 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index cbcd0dcb26..1f49abfa87 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -72,7 +72,7 @@ import type { import { DEFAULT_ASPECT_RATIO_CONFIG, getEntityIdentifier, - getInitialState, + getInitialCanvasState, isChatGPT4oAspectRatioID, isFluxKontextAspectRatioID, isImagenAspectRatioID, @@ -97,7 +97,7 @@ import { export const canvasSlice = createSlice({ name: 'canvas', - initialState: getInitialState(), + initialState: getInitialCanvasState(), reducers: { // undoable canvas state //#region Raster layers @@ -1745,7 +1745,7 @@ export const canvasSlice = createSlice({ }, allEntitiesDeleted: (state) => { // Deleting all entities is equivalent to resetting the state for each entity type - const initialState = getInitialState(); + const initialState = getInitialCanvasState(); state.rasterLayers = initialState.rasterLayers; state.controlLayers = initialState.controlLayers; state.inpaintMasks = initialState.inpaintMasks; @@ -1809,7 +1809,7 @@ export const canvasSlice = createSlice({ }); const resetState = (state: CanvasState) => { - const newState = getInitialState(); + const newState = getInitialCanvasState(); // We need to retain the optimal dimension across resets, as it is changed only when the model changes. Copy it // from the old state, then recalculate the bbox size & scaled size. @@ -1931,7 +1931,7 @@ const migrate = (state: any): any => { export const canvasPersistConfig: PersistConfig = { name: canvasSlice.name, - initialState: getInitialState(), + initialState: getInitialCanvasState(), migrate, persistDenylist: [], }; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 72060f38ad..a6a7e8d98d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -1,8 +1,8 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; -import { deepClone } from 'common/util/deepClone'; -import { type RgbaColor, zRgbaColor } from 'features/controlLayers/store/types'; +import type { ParamsState, RgbaColor } from 'features/controlLayers/store/types'; +import { getInitialParamsState } from 'features/controlLayers/store/types'; import { CLIP_SKIP_MAP } from 'features/parameters/types/constants'; import type { ParameterCanvasCoherenceMode, @@ -20,92 +20,13 @@ import type { ParameterT5EncoderModel, ParameterVAEModel, } from 'features/parameters/types/parameterSchemas'; -import { - zParameterCanvasCoherenceMode, - zParameterCFGRescaleMultiplier, - zParameterCFGScale, - zParameterCLIPEmbedModel, - zParameterCLIPGEmbedModel, - zParameterCLIPLEmbedModel, - zParameterControlLoRAModel, - zParameterGuidance, - zParameterMaskBlurMethod, - zParameterModel, - zParameterNegativePrompt, - zParameterNegativeStylePromptSDXL, - zParameterPositivePrompt, - zParameterPositiveStylePromptSDXL, - zParameterPrecision, - zParameterScheduler, - zParameterSDXLRefinerModel, - zParameterSeed, - zParameterSteps, - zParameterStrength, - zParameterT5EncoderModel, - zParameterVAEModel, -} from 'features/parameters/types/parameterSchemas'; import { clamp } from 'lodash-es'; -import { z } from 'zod'; import { newSessionRequested } from './actions'; -const zParamsState = z.object({ - maskBlur: z.number().default(16), - maskBlurMethod: zParameterMaskBlurMethod.default('box'), - canvasCoherenceMode: zParameterCanvasCoherenceMode.default('Gaussian Blur'), - canvasCoherenceMinDenoise: zParameterStrength.default(0), - canvasCoherenceEdgeSize: z.number().default(16), - infillMethod: z.string().default('lama'), - infillTileSize: z.number().default(32), - infillPatchmatchDownscaleSize: z.number().default(1), - infillColorValue: zRgbaColor.default({ r: 0, g: 0, b: 0, a: 1 }), - cfgScale: zParameterCFGScale.default(7.5), - cfgRescaleMultiplier: zParameterCFGRescaleMultiplier.default(0), - guidance: zParameterGuidance.default(4), - img2imgStrength: zParameterStrength.default(0.75), - optimizedDenoisingEnabled: z.boolean().default(true), - iterations: z.number().default(1), - scheduler: zParameterScheduler.default('dpmpp_3m_k'), - upscaleScheduler: zParameterScheduler.default('kdpm_2'), - upscaleCfgScale: zParameterCFGScale.default(2), - seed: zParameterSeed.default(0), - shouldRandomizeSeed: z.boolean().default(true), - steps: zParameterSteps.default(30), - model: zParameterModel.nullable().default(null), - vae: zParameterVAEModel.nullable().default(null), - vaePrecision: zParameterPrecision.default('fp32'), - fluxVAE: zParameterVAEModel.nullable().default(null), - seamlessXAxis: z.boolean().default(false), - seamlessYAxis: z.boolean().default(false), - clipSkip: z.number().default(0), - shouldUseCpuNoise: z.boolean().default(true), - positivePrompt: zParameterPositivePrompt.default(''), - negativePrompt: zParameterNegativePrompt.default(''), - positivePrompt2: zParameterPositiveStylePromptSDXL.default(''), - negativePrompt2: zParameterNegativeStylePromptSDXL.default(''), - shouldConcatPrompts: z.boolean().default(true), - refinerModel: zParameterSDXLRefinerModel.nullable().default(null), - refinerSteps: z.number().default(20), - refinerCFGScale: z.number().default(7.5), - refinerScheduler: zParameterScheduler.default('euler'), - refinerPositiveAestheticScore: z.number().default(6), - refinerNegativeAestheticScore: z.number().default(2.5), - refinerStart: z.number().default(0.8), - t5EncoderModel: zParameterT5EncoderModel.nullable().default(null), - clipEmbedModel: zParameterCLIPEmbedModel.nullable().default(null), - clipLEmbedModel: zParameterCLIPLEmbedModel.nullable().default(null), - clipGEmbedModel: zParameterCLIPGEmbedModel.nullable().default(null), - controlLora: zParameterControlLoRAModel.nullable().default(null), -}); - -export type ParamsState = z.infer; - -const INITIAL_STATE = zParamsState.parse({}); -const getInitialState = () => deepClone(INITIAL_STATE); - export const paramsSlice = createSlice({ name: 'params', - initialState: getInitialState(), + initialState: getInitialParamsState(), reducers: { setIterations: (state, action: PayloadAction) => { state.iterations = action.payload; @@ -273,7 +194,7 @@ export const paramsSlice = createSlice({ const resetState = (state: ParamsState): ParamsState => { // When a new session is requested, we need to keep the current model selections, plus dependent state // like VAE precision. Everything else gets reset to default. - const newState = getInitialState(); + const newState = getInitialParamsState(); newState.model = state.model; newState.vae = state.vae; newState.fluxVAE = state.fluxVAE; @@ -339,7 +260,7 @@ const migrate = (state: any): any => { export const paramsPersistConfig: PersistConfig = { name: paramsSlice.name, - initialState: getInitialState(), + initialState: getInitialParamsState(), migrate, persistDenylist: [], }; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 7040938ab2..fca9e1a6ce 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -4,9 +4,29 @@ import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchi import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; import { + zParameterCanvasCoherenceMode, + zParameterCFGRescaleMultiplier, + zParameterCFGScale, + zParameterCLIPEmbedModel, + zParameterCLIPGEmbedModel, + zParameterCLIPLEmbedModel, + zParameterControlLoRAModel, + zParameterGuidance, zParameterImageDimension, + zParameterMaskBlurMethod, + zParameterModel, zParameterNegativePrompt, + zParameterNegativeStylePromptSDXL, zParameterPositivePrompt, + zParameterPositiveStylePromptSDXL, + zParameterPrecision, + zParameterScheduler, + zParameterSDXLRefinerModel, + zParameterSeed, + zParameterSteps, + zParameterStrength, + zParameterT5EncoderModel, + zParameterVAEModel, } from 'features/parameters/types/parameterSchemas'; import { getImageDTOSafe } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; @@ -467,6 +487,59 @@ const zBboxState = z.object({ scaleMethod: zBoundingBoxScaleMethod, modelBase: zMainModelBase, }); + +const zParamsState = z.object({ + maskBlur: z.number().default(16), + maskBlurMethod: zParameterMaskBlurMethod.default('box'), + canvasCoherenceMode: zParameterCanvasCoherenceMode.default('Gaussian Blur'), + canvasCoherenceMinDenoise: zParameterStrength.default(0), + canvasCoherenceEdgeSize: z.number().default(16), + infillMethod: z.string().default('lama'), + infillTileSize: z.number().default(32), + infillPatchmatchDownscaleSize: z.number().default(1), + infillColorValue: zRgbaColor.default({ r: 0, g: 0, b: 0, a: 1 }), + cfgScale: zParameterCFGScale.default(7.5), + cfgRescaleMultiplier: zParameterCFGRescaleMultiplier.default(0), + guidance: zParameterGuidance.default(4), + img2imgStrength: zParameterStrength.default(0.75), + optimizedDenoisingEnabled: z.boolean().default(true), + iterations: z.number().default(1), + scheduler: zParameterScheduler.default('dpmpp_3m_k'), + upscaleScheduler: zParameterScheduler.default('kdpm_2'), + upscaleCfgScale: zParameterCFGScale.default(2), + seed: zParameterSeed.default(0), + shouldRandomizeSeed: z.boolean().default(true), + steps: zParameterSteps.default(30), + model: zParameterModel.nullable().default(null), + vae: zParameterVAEModel.nullable().default(null), + vaePrecision: zParameterPrecision.default('fp32'), + fluxVAE: zParameterVAEModel.nullable().default(null), + seamlessXAxis: z.boolean().default(false), + seamlessYAxis: z.boolean().default(false), + clipSkip: z.number().default(0), + shouldUseCpuNoise: z.boolean().default(true), + positivePrompt: zParameterPositivePrompt.default(''), + negativePrompt: zParameterNegativePrompt.default(''), + positivePrompt2: zParameterPositiveStylePromptSDXL.default(''), + negativePrompt2: zParameterNegativeStylePromptSDXL.default(''), + shouldConcatPrompts: z.boolean().default(true), + refinerModel: zParameterSDXLRefinerModel.nullable().default(null), + refinerSteps: z.number().default(20), + refinerCFGScale: z.number().default(7.5), + refinerScheduler: zParameterScheduler.default('euler'), + refinerPositiveAestheticScore: z.number().default(6), + refinerNegativeAestheticScore: z.number().default(2.5), + refinerStart: z.number().default(0.8), + t5EncoderModel: zParameterT5EncoderModel.nullable().default(null), + clipEmbedModel: zParameterCLIPEmbedModel.nullable().default(null), + clipLEmbedModel: zParameterCLIPLEmbedModel.nullable().default(null), + clipGEmbedModel: zParameterCLIPGEmbedModel.nullable().default(null), + controlLora: zParameterControlLoRAModel.nullable().default(null), +}); +export type ParamsState = z.infer; +const INITIAL_PARAMS_STATE = zParamsState.parse({}); +export const getInitialParamsState = () => deepClone(INITIAL_PARAMS_STATE); + const zCanvasState = z.object({ _version: z.literal(3).default(3), isSessionStarted: z.boolean().default(false), @@ -518,7 +591,7 @@ export type CanvasState = z.infer; * Gets a fresh canvas initial state with no references in memory to existing objects. */ const CANVAS_INITIAL_STATE = zCanvasState.parse({}); -export const getInitialState = () => deepClone(CANVAS_INITIAL_STATE); +export const getInitialCanvasState = () => deepClone(CANVAS_INITIAL_STATE); export const zCanvasMetadata = z.object({ inpaintMasks: z.array(zCanvasInpaintMaskState),