mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 22:35:27 -05:00
refactor(ui): zod-ify params slice state
This commit is contained in:
@@ -2,7 +2,7 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import type { RgbaColor } from 'features/controlLayers/store/types';
|
||||
import { type RgbaColor, zRgbaColor } from 'features/controlLayers/store/types';
|
||||
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
|
||||
import type {
|
||||
ParameterCanvasCoherenceMode,
|
||||
@@ -13,126 +13,99 @@ import type {
|
||||
ParameterCLIPLEmbedModel,
|
||||
ParameterControlLoRAModel,
|
||||
ParameterGuidance,
|
||||
ParameterMaskBlurMethod,
|
||||
ParameterModel,
|
||||
ParameterNegativePrompt,
|
||||
ParameterNegativeStylePromptSDXL,
|
||||
ParameterPositivePrompt,
|
||||
ParameterPositiveStylePromptSDXL,
|
||||
ParameterPrecision,
|
||||
ParameterScheduler,
|
||||
ParameterSDXLRefinerModel,
|
||||
ParameterSeed,
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterT5EncoderModel,
|
||||
ParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
zParameterCanvasCoherenceMode,
|
||||
zParameterCFGRescaleMultiplier,
|
||||
zParameterCFGScale,
|
||||
zParameterCLIPEmbedModel,
|
||||
zParameterCLIPGEmbedModel,
|
||||
zParameterCLIPLEmbedModel,
|
||||
zParameterControlLoRAModel,
|
||||
zParameterGuidance,
|
||||
zParameterMaskBlurMethod,
|
||||
zParameterModel,
|
||||
zParameterNegativePrompt,
|
||||
zParameterNegativeStylePromptSDXL,
|
||||
zParameterPositivePrompt,
|
||||
zParameterPositiveStylePromptSDXL,
|
||||
zParameterPrecision,
|
||||
zParameterScheduler,
|
||||
zParameterSDXLRefinerModel,
|
||||
zParameterSeed,
|
||||
zParameterSteps,
|
||||
zParameterStrength,
|
||||
zParameterT5EncoderModel,
|
||||
zParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { newSessionRequested } from './actions';
|
||||
|
||||
export type ParamsState = {
|
||||
maskBlur: number;
|
||||
maskBlurMethod: ParameterMaskBlurMethod;
|
||||
canvasCoherenceMode: ParameterCanvasCoherenceMode;
|
||||
canvasCoherenceMinDenoise: ParameterStrength;
|
||||
canvasCoherenceEdgeSize: number;
|
||||
infillMethod: string;
|
||||
infillTileSize: number;
|
||||
infillPatchmatchDownscaleSize: number;
|
||||
infillColorValue: RgbaColor;
|
||||
cfgScale: ParameterCFGScale;
|
||||
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
|
||||
guidance: ParameterGuidance;
|
||||
img2imgStrength: ParameterStrength;
|
||||
optimizedDenoisingEnabled: boolean;
|
||||
iterations: number;
|
||||
scheduler: ParameterScheduler;
|
||||
upscaleScheduler: ParameterScheduler;
|
||||
upscaleCfgScale: ParameterCFGScale;
|
||||
seed: ParameterSeed;
|
||||
shouldRandomizeSeed: boolean;
|
||||
steps: ParameterSteps;
|
||||
model: ParameterModel | null;
|
||||
vae: ParameterVAEModel | null;
|
||||
vaePrecision: ParameterPrecision;
|
||||
fluxVAE: ParameterVAEModel | null;
|
||||
seamlessXAxis: boolean;
|
||||
seamlessYAxis: boolean;
|
||||
clipSkip: number;
|
||||
shouldUseCpuNoise: boolean;
|
||||
positivePrompt: ParameterPositivePrompt;
|
||||
negativePrompt: ParameterNegativePrompt;
|
||||
positivePrompt2: ParameterPositiveStylePromptSDXL;
|
||||
negativePrompt2: ParameterNegativeStylePromptSDXL;
|
||||
shouldConcatPrompts: boolean;
|
||||
refinerModel: ParameterSDXLRefinerModel | null;
|
||||
refinerSteps: number;
|
||||
refinerCFGScale: number;
|
||||
refinerScheduler: ParameterScheduler;
|
||||
refinerPositiveAestheticScore: number;
|
||||
refinerNegativeAestheticScore: number;
|
||||
refinerStart: number;
|
||||
t5EncoderModel: ParameterT5EncoderModel | null;
|
||||
clipEmbedModel: ParameterCLIPEmbedModel | null;
|
||||
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
|
||||
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
|
||||
controlLora: ParameterControlLoRAModel | null;
|
||||
};
|
||||
const zParamsState = z.object({
|
||||
maskBlur: z.number().default(16),
|
||||
maskBlurMethod: zParameterMaskBlurMethod.default('box'),
|
||||
canvasCoherenceMode: zParameterCanvasCoherenceMode.default('Gaussian Blur'),
|
||||
canvasCoherenceMinDenoise: zParameterStrength.default(0),
|
||||
canvasCoherenceEdgeSize: z.number().default(16),
|
||||
infillMethod: z.string().default('lama'),
|
||||
infillTileSize: z.number().default(32),
|
||||
infillPatchmatchDownscaleSize: z.number().default(1),
|
||||
infillColorValue: zRgbaColor.default({ r: 0, g: 0, b: 0, a: 1 }),
|
||||
cfgScale: zParameterCFGScale.default(7.5),
|
||||
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier.default(0),
|
||||
guidance: zParameterGuidance.default(4),
|
||||
img2imgStrength: zParameterStrength.default(0.75),
|
||||
optimizedDenoisingEnabled: z.boolean().default(true),
|
||||
iterations: z.number().default(1),
|
||||
scheduler: zParameterScheduler.default('dpmpp_3m_k'),
|
||||
upscaleScheduler: zParameterScheduler.default('kdpm_2'),
|
||||
upscaleCfgScale: zParameterCFGScale.default(2),
|
||||
seed: zParameterSeed.default(0),
|
||||
shouldRandomizeSeed: z.boolean().default(true),
|
||||
steps: zParameterSteps.default(30),
|
||||
model: zParameterModel.nullable().default(null),
|
||||
vae: zParameterVAEModel.nullable().default(null),
|
||||
vaePrecision: zParameterPrecision.default('fp32'),
|
||||
fluxVAE: zParameterVAEModel.nullable().default(null),
|
||||
seamlessXAxis: z.boolean().default(false),
|
||||
seamlessYAxis: z.boolean().default(false),
|
||||
clipSkip: z.number().default(0),
|
||||
shouldUseCpuNoise: z.boolean().default(true),
|
||||
positivePrompt: zParameterPositivePrompt.default(''),
|
||||
negativePrompt: zParameterNegativePrompt.default(''),
|
||||
positivePrompt2: zParameterPositiveStylePromptSDXL.default(''),
|
||||
negativePrompt2: zParameterNegativeStylePromptSDXL.default(''),
|
||||
shouldConcatPrompts: z.boolean().default(true),
|
||||
refinerModel: zParameterSDXLRefinerModel.nullable().default(null),
|
||||
refinerSteps: z.number().default(20),
|
||||
refinerCFGScale: z.number().default(7.5),
|
||||
refinerScheduler: zParameterScheduler.default('euler'),
|
||||
refinerPositiveAestheticScore: z.number().default(6),
|
||||
refinerNegativeAestheticScore: z.number().default(2.5),
|
||||
refinerStart: z.number().default(0.8),
|
||||
t5EncoderModel: zParameterT5EncoderModel.nullable().default(null),
|
||||
clipEmbedModel: zParameterCLIPEmbedModel.nullable().default(null),
|
||||
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable().default(null),
|
||||
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable().default(null),
|
||||
controlLora: zParameterControlLoRAModel.nullable().default(null),
|
||||
});
|
||||
|
||||
const initialState: ParamsState = {
|
||||
maskBlur: 16,
|
||||
maskBlurMethod: 'box',
|
||||
canvasCoherenceMode: 'Gaussian Blur',
|
||||
canvasCoherenceMinDenoise: 0,
|
||||
canvasCoherenceEdgeSize: 16,
|
||||
infillMethod: 'lama',
|
||||
infillTileSize: 32,
|
||||
infillPatchmatchDownscaleSize: 1,
|
||||
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
|
||||
cfgScale: 7.5,
|
||||
cfgRescaleMultiplier: 0,
|
||||
guidance: 4,
|
||||
img2imgStrength: 0.75,
|
||||
optimizedDenoisingEnabled: true,
|
||||
iterations: 1,
|
||||
scheduler: 'dpmpp_3m_k',
|
||||
upscaleScheduler: 'kdpm_2',
|
||||
upscaleCfgScale: 2,
|
||||
seed: 0,
|
||||
shouldRandomizeSeed: true,
|
||||
steps: 30,
|
||||
model: null,
|
||||
vae: null,
|
||||
fluxVAE: null,
|
||||
vaePrecision: 'fp32',
|
||||
seamlessXAxis: false,
|
||||
seamlessYAxis: false,
|
||||
clipSkip: 0,
|
||||
shouldUseCpuNoise: true,
|
||||
positivePrompt: '',
|
||||
negativePrompt: '',
|
||||
positivePrompt2: '',
|
||||
negativePrompt2: '',
|
||||
shouldConcatPrompts: true,
|
||||
refinerModel: null,
|
||||
refinerSteps: 20,
|
||||
refinerCFGScale: 7.5,
|
||||
refinerScheduler: 'euler',
|
||||
refinerPositiveAestheticScore: 6,
|
||||
refinerNegativeAestheticScore: 2.5,
|
||||
refinerStart: 0.8,
|
||||
t5EncoderModel: null,
|
||||
clipEmbedModel: null,
|
||||
clipLEmbedModel: null,
|
||||
clipGEmbedModel: null,
|
||||
controlLora: null,
|
||||
};
|
||||
export type ParamsState = z.infer<typeof zParamsState>;
|
||||
|
||||
const INITIAL_STATE = zParamsState.parse({});
|
||||
const getInitialState = () => deepClone(INITIAL_STATE);
|
||||
|
||||
export const paramsSlice = createSlice({
|
||||
name: 'params',
|
||||
initialState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
setIterations: (state, action: PayloadAction<number>) => {
|
||||
state.iterations = action.payload;
|
||||
@@ -300,7 +273,7 @@ export const paramsSlice = createSlice({
|
||||
const resetState = (state: ParamsState): ParamsState => {
|
||||
// When a new session is requested, we need to keep the current model selections, plus dependent state
|
||||
// like VAE precision. Everything else gets reset to default.
|
||||
const newState = deepClone(initialState);
|
||||
const newState = getInitialState();
|
||||
newState.model = state.model;
|
||||
newState.vae = state.vae;
|
||||
newState.fluxVAE = state.fluxVAE;
|
||||
@@ -366,7 +339,7 @@ const migrate = (state: any): any => {
|
||||
|
||||
export const paramsPersistConfig: PersistConfig<ParamsState> = {
|
||||
name: paramsSlice.name,
|
||||
initialState,
|
||||
initialState: getInitialState(),
|
||||
migrate,
|
||||
persistDenylist: [],
|
||||
};
|
||||
|
||||
@@ -72,7 +72,7 @@ const zRgbColor = z.object({
|
||||
b: z.number().int().min(0).max(255),
|
||||
});
|
||||
export type RgbColor = z.infer<typeof zRgbColor>;
|
||||
const zRgbaColor = zRgbColor.extend({
|
||||
export const zRgbaColor = zRgbColor.extend({
|
||||
a: z.number().min(0).max(1),
|
||||
});
|
||||
export type RgbaColor = z.infer<typeof zRgbaColor>;
|
||||
|
||||
@@ -96,7 +96,7 @@ export type ParameterModel = z.infer<typeof zParameterModel>;
|
||||
// #endregion
|
||||
|
||||
// #region SDXL Refiner Model
|
||||
const zParameterSDXLRefinerModel = zModelIdentifierField;
|
||||
export const zParameterSDXLRefinerModel = zModelIdentifierField;
|
||||
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
|
||||
// #endregion
|
||||
|
||||
@@ -188,7 +188,7 @@ export type ParameterSDXLRefinerStart = z.infer<typeof zParameterSDXLRefinerStar
|
||||
// #endregion
|
||||
|
||||
// #region Mask Blur Method
|
||||
const zParameterMaskBlurMethod = z.enum(['box', 'gaussian']);
|
||||
export const [zParameterMaskBlurMethod, isParameterMaskBlurMethod] = buildParameter(z.enum(['box', 'gaussian']));
|
||||
export type ParameterMaskBlurMethod = z.infer<typeof zParameterMaskBlurMethod>;
|
||||
// #endregion
|
||||
|
||||
|
||||
Reference in New Issue
Block a user