mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-17 14:34:39 -05:00
refactor(ui): use zod for all redux state
This commit is contained in:
@@ -112,7 +112,7 @@ const getInitialState = (): CanvasSettingsState => ({
|
||||
pressureSensitivity: true,
|
||||
ruleOfThirds: false,
|
||||
saveAllImagesToGallery: false,
|
||||
stagingAreaAutoSwitch: 'switch_on_start' as const,
|
||||
stagingAreaAutoSwitch: 'switch_on_start',
|
||||
});
|
||||
|
||||
const slice = createSlice({
|
||||
@@ -209,7 +209,7 @@ export const {
|
||||
|
||||
export const canvasSettingsSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
zSchema: zCanvasSettingsState,
|
||||
schema: zCanvasSettingsState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => zCanvasSettingsState.parse(state),
|
||||
|
||||
@@ -1720,7 +1720,7 @@ const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
|
||||
export const canvasSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
getInitialState: getInitialCanvasState,
|
||||
zSchema: zCanvasState,
|
||||
schema: zCanvasState,
|
||||
persistConfig: {
|
||||
migrate: (state) => zCanvasState.parse(state),
|
||||
},
|
||||
|
||||
@@ -3,19 +3,25 @@ import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { useMemo } from 'react';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { assert } from 'tsafe';
|
||||
import z from 'zod';
|
||||
|
||||
const zCanvasStagingAreaState = z.object({
|
||||
_version: z.literal(1).default(1),
|
||||
canvasSessionId: z.string().default(() => getPrefixedId('canvas')),
|
||||
canvasDiscardedQueueItems: z.array(z.number().int()).default(() => []),
|
||||
_version: z.literal(1),
|
||||
canvasSessionId: z.string(),
|
||||
canvasDiscardedQueueItems: z.array(z.number().int()),
|
||||
});
|
||||
type CanvasStagingAreaState = z.infer<typeof zCanvasStagingAreaState>;
|
||||
|
||||
const getInitialState = (): CanvasStagingAreaState => zCanvasStagingAreaState.parse({});
|
||||
const getInitialState = (): CanvasStagingAreaState => ({
|
||||
_version: 1,
|
||||
canvasSessionId: getPrefixedId('canvas'),
|
||||
canvasDiscardedQueueItems: [],
|
||||
});
|
||||
|
||||
const slice = createSlice({
|
||||
name: 'canvasSession',
|
||||
@@ -48,18 +54,17 @@ export const { canvasSessionReset, canvasQueueItemDiscarded } = slice.actions;
|
||||
|
||||
export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
zSchema: zCanvasStagingAreaState,
|
||||
schema: zCanvasStagingAreaState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
{
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
|
||||
}
|
||||
|
||||
return zCanvasStagingAreaState.parse(state);
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
|
||||
}
|
||||
|
||||
return zCanvasStagingAreaState.parse(state);
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -2,14 +2,16 @@ import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolki
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { LoRA } from 'features/controlLayers/store/types';
|
||||
import { type LoRA, zLoRA } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import z from 'zod';
|
||||
|
||||
type LoRAsState = {
|
||||
loras: LoRA[];
|
||||
};
|
||||
const zLoRAsState = z.object({
|
||||
loras: z.array(zLoRA),
|
||||
});
|
||||
type LoRAsState = z.infer<typeof zLoRAsState>;
|
||||
|
||||
const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
|
||||
weight: 0.75,
|
||||
@@ -74,16 +76,12 @@ const slice = createSlice({
|
||||
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
|
||||
slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
return state;
|
||||
};
|
||||
|
||||
export const lorasSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zLoRAsState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate,
|
||||
migrate: (state) => zLoRAsState.parse(state),
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -403,7 +403,7 @@ export const {
|
||||
|
||||
export const paramsSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
zSchema: zParamsState,
|
||||
schema: zParamsState,
|
||||
getInitialState: getInitialParamsState,
|
||||
persistConfig: {
|
||||
migrate: (state) => zParamsState.parse(state),
|
||||
|
||||
@@ -266,17 +266,12 @@ export const {
|
||||
refImagesRecalled,
|
||||
} = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
return state;
|
||||
};
|
||||
|
||||
export const refImagesSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
zSchema: zRefImagesState,
|
||||
schema: zRefImagesState,
|
||||
getInitialState: getInitialRefImagesState,
|
||||
persistConfig: {
|
||||
migrate,
|
||||
migrate: (state) => zRefImagesState.parse(state),
|
||||
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
|
||||
},
|
||||
};
|
||||
|
||||
@@ -3,7 +3,6 @@ import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEnt
|
||||
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { ProgressImage } from 'features/nodes/types/common';
|
||||
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
zParameterCanvasCoherenceMode,
|
||||
zParameterCFGRescaleMultiplier,
|
||||
@@ -45,7 +44,7 @@ const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async
|
||||
}
|
||||
});
|
||||
|
||||
const zImageWithDims = z
|
||||
export const zImageWithDims = z
|
||||
.object({
|
||||
image_name: z.string(),
|
||||
width: z.number().int().positive(),
|
||||
@@ -424,12 +423,13 @@ export const zCanvasEntityIdentifer = z.object({
|
||||
});
|
||||
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
|
||||
|
||||
export type LoRA = {
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
model: ParameterLoRAModel;
|
||||
weight: number;
|
||||
};
|
||||
export const zLoRA = z.object({
|
||||
id: z.string(),
|
||||
isEnabled: z.boolean(),
|
||||
model: zServerValidatedModelIdentifierField,
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
});
|
||||
export type LoRA = z.infer<typeof zLoRA>;
|
||||
|
||||
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
|
||||
|
||||
@@ -574,11 +574,11 @@ export const zParamsState = z.object({
|
||||
export type ParamsState = z.infer<typeof zParamsState>;
|
||||
export const getInitialParamsState = (): ParamsState => ({
|
||||
maskBlur: 16,
|
||||
maskBlurMethod: 'box' as const,
|
||||
canvasCoherenceMode: 'Gaussian Blur' as const,
|
||||
maskBlurMethod: 'box',
|
||||
canvasCoherenceMode: 'Gaussian Blur',
|
||||
canvasCoherenceMinDenoise: 0,
|
||||
canvasCoherenceEdgeSize: 16,
|
||||
infillMethod: 'lama' as const,
|
||||
infillMethod: 'lama',
|
||||
infillTileSize: 32,
|
||||
infillPatchmatchDownscaleSize: 1,
|
||||
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
|
||||
@@ -588,15 +588,15 @@ export const getInitialParamsState = (): ParamsState => ({
|
||||
img2imgStrength: 0.75,
|
||||
optimizedDenoisingEnabled: true,
|
||||
iterations: 1,
|
||||
scheduler: 'dpmpp_3m_k' as const,
|
||||
upscaleScheduler: 'kdpm_2' as const,
|
||||
scheduler: 'dpmpp_3m_k',
|
||||
upscaleScheduler: 'kdpm_2',
|
||||
upscaleCfgScale: 2,
|
||||
seed: 0,
|
||||
shouldRandomizeSeed: true,
|
||||
steps: 30,
|
||||
model: null,
|
||||
vae: null,
|
||||
vaePrecision: 'fp32' as const,
|
||||
vaePrecision: 'fp32',
|
||||
fluxVAE: null,
|
||||
seamlessXAxis: false,
|
||||
seamlessYAxis: false,
|
||||
@@ -610,7 +610,7 @@ export const getInitialParamsState = (): ParamsState => ({
|
||||
refinerModel: null,
|
||||
refinerSteps: 20,
|
||||
refinerCFGScale: 7.5,
|
||||
refinerScheduler: 'euler' as const,
|
||||
refinerScheduler: 'euler',
|
||||
refinerPositiveAestheticScore: 6,
|
||||
refinerNegativeAestheticScore: 2.5,
|
||||
refinerStart: 0.8,
|
||||
@@ -653,7 +653,7 @@ export const zCanvasState = z.object({
|
||||
});
|
||||
export type CanvasState = z.infer<typeof zCanvasState>;
|
||||
export const getInitialCanvasState = (): CanvasState => ({
|
||||
_version: 3 as const,
|
||||
_version: 3,
|
||||
selectedEntityIdentifier: null,
|
||||
bookmarkedEntityIdentifier: null,
|
||||
inpaintMasks: { isHidden: false, entities: [] },
|
||||
@@ -663,9 +663,9 @@ export const getInitialCanvasState = (): CanvasState => ({
|
||||
bbox: {
|
||||
rect: { x: 0, y: 0, width: 512, height: 512 },
|
||||
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
|
||||
scaleMethod: 'auto' as const,
|
||||
scaleMethod: 'auto',
|
||||
scaledSize: { width: 512, height: 512 },
|
||||
modelBase: 'sd-1' as const,
|
||||
modelBase: 'sd-1',
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user