From aed9b1013ed61357c8a49201f2330ff29ad84352 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 25 Jul 2025 13:13:28 +1000 Subject: [PATCH] refactor(ui): use zod for all redux state --- invokeai/frontend/web/eslint.config.mjs | 4 ++ invokeai/frontend/web/src/app/store/store.ts | 18 +++-- invokeai/frontend/web/src/app/store/types.ts | 9 +-- .../frontend/web/src/app/types/invokeai.ts | 1 - .../features/changeBoardModal/store/slice.ts | 2 +- .../store/canvasSettingsSlice.ts | 4 +- .../controlLayers/store/canvasSlice.ts | 2 +- .../store/canvasStagingAreaSlice.ts | 29 ++++---- .../controlLayers/store/lorasSlice.ts | 18 +++-- .../controlLayers/store/paramsSlice.ts | 2 +- .../controlLayers/store/refImagesSlice.ts | 9 +-- .../src/features/controlLayers/store/types.ts | 36 +++++----- .../store/dynamicPromptsSlice.ts | 7 +- .../ImageMenuItemSendToUpscale.tsx | 3 +- .../features/gallery/store/gallerySlice.ts | 28 +++++--- .../src/features/gallery/store/types.test.ts | 13 ++++ .../web/src/features/gallery/store/types.ts | 62 ++++++++++------- .../web/src/features/imageActions/actions.ts | 2 +- .../store/modelManagerV2Slice.ts | 45 +++++++------ .../features/nodes/components/flow/Flow.tsx | 10 ++- .../src/features/nodes/store/nodesSlice.ts | 21 +++--- .../web/src/features/nodes/store/types.ts | 25 ++++--- .../nodes/store/workflowLibrarySlice.ts | 36 +++++----- .../nodes/store/workflowSettingsSlice.ts | 59 ++++++++-------- .../web/src/features/nodes/types/common.ts | 2 +- .../src/features/nodes/types/invocation.ts | 61 ++++++++++++++--- .../Upscale/ParamTileControlNetModel.tsx | 27 +++++++- .../hooks/useIsTooLargeToUpscale.ts | 12 ++-- .../features/parameters/store/upscaleSlice.ts | 67 ++++++++++++------- .../src/features/queue/store/queueSlice.ts | 15 +++-- .../UpscaleInitialImage.tsx | 7 +- .../stylePresets/store/stylePresetSlice.ts | 29 +++++--- .../src/features/stylePresets/store/types.ts | 6 -- .../src/features/system/store/configSlice.ts | 2 +- .../src/features/system/store/systemSlice.ts | 29 ++++---- .../web/src/features/system/store/types.ts | 35 +++++----- .../ui/layouts/UpscalingLaunchpadPanel.tsx | 3 +- .../web/src/features/ui/store/uiSlice.ts | 5 +- .../frontend/web/src/services/api/types.ts | 42 ++++++++++-- 39 files changed, 488 insertions(+), 299 deletions(-) create mode 100644 invokeai/frontend/web/src/features/gallery/store/types.test.ts delete mode 100644 invokeai/frontend/web/src/features/stylePresets/store/types.ts diff --git a/invokeai/frontend/web/eslint.config.mjs b/invokeai/frontend/web/eslint.config.mjs index 6449cfb627..0adc887ceb 100644 --- a/invokeai/frontend/web/eslint.config.mjs +++ b/invokeai/frontend/web/eslint.config.mjs @@ -197,6 +197,10 @@ export default [ importNames: ['isEqual'], message: 'Please use objectEquals from @observ33r/object-equals instead.', }, + { + name: 'zod/v3', + message: 'Import from zod instead.', + }, ], }, ], diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 96781cc934..ab38e1e583 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -128,28 +128,26 @@ const unserialize: UnserializeFunction = (data, key) => { try { const initialState = getInitialState(); - const parsed = JSON.parse(data); - // strip out old keys - const stripped = pick(deepClone(parsed), keys(initialState)); - // run (additive) migrations - const migrated = persistConfig.migrate(stripped); + const stripped = pick(deepClone(data), keys(initialState)); /* * Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep, * but that merges arrays by index and partial objects by key. Using an identity function as the customizer results * in behaviour like defaultsDeep, but doesn't overwrite any values that are not undefined in the migrated state. */ - const transformed = mergeWith(migrated, initialState, (objVal) => objVal); + const unPersistDenylisted = mergeWith(stripped, initialState, (objVal) => objVal); + // run (additive) migrations + const migrated = persistConfig.migrate(unPersistDenylisted); log.debug( { - persistedData: parsed, - rehydratedData: transformed as JsonObject, - diff: diff(parsed, transformed) as JsonObject, + persistedData: data as JsonObject, + rehydratedData: migrated as JsonObject, + diff: diff(data, migrated) as JsonObject, }, `Rehydrated slice "${key}"` ); - state = transformed; + state = migrated; } catch (err) { log.warn( { error: serializeError(err as Error) }, diff --git a/invokeai/frontend/web/src/app/store/types.ts b/invokeai/frontend/web/src/app/store/types.ts index 8787c93e36..28b28e1889 100644 --- a/invokeai/frontend/web/src/app/store/types.ts +++ b/invokeai/frontend/web/src/app/store/types.ts @@ -1,4 +1,3 @@ -/* eslint-disable @typescript-eslint/no-explicit-any */ import type { Slice } from '@reduxjs/toolkit'; import type { UndoableOptions } from 'redux-undo'; import type { ZodType } from 'zod'; @@ -13,7 +12,7 @@ export type SliceConfig = { /** * The zod schema for the slice. */ - zSchema: ZodType>; + schema: ZodType>; /** * A function that returns the initial state of the slice. */ @@ -23,11 +22,13 @@ export type SliceConfig = { */ persistConfig?: { /** - * Migrate the state to the current version during rehydration. + * Migrate the state to the current version during rehydration. This method should throw an error if the migration + * fails. + * * @param state The rehydrated state. * @returns A correctly-shaped state. */ - migrate: (state: any) => StateFromSlice; + migrate: (state: unknown) => StateFromSlice; /** * Keys to omit from the persisted state. */ diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 02be210de7..6bb84e1f23 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -58,7 +58,6 @@ const zNumericalParameterConfig = z.object({ fineStep: z.number().default(8), coarseStep: z.number().default(64), }); -export type NumericalParameterConfig = z.infer; /** * Configuration options for the InvokeAI UI. diff --git a/invokeai/frontend/web/src/features/changeBoardModal/store/slice.ts b/invokeai/frontend/web/src/features/changeBoardModal/store/slice.ts index fc652ec44f..3f72720a42 100644 --- a/invokeai/frontend/web/src/features/changeBoardModal/store/slice.ts +++ b/invokeai/frontend/web/src/features/changeBoardModal/store/slice.ts @@ -35,6 +35,6 @@ export const selectChangeBoardModalSlice = (state: RootState) => state.changeBoa export const changeBoardModalSliceConfig: SliceConfig = { slice, - zSchema: zChangeBoardModalState, + schema: zChangeBoardModalState, getInitialState, }; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts index 3587d07c1c..8c62f8cb02 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts @@ -112,7 +112,7 @@ const getInitialState = (): CanvasSettingsState => ({ pressureSensitivity: true, ruleOfThirds: false, saveAllImagesToGallery: false, - stagingAreaAutoSwitch: 'switch_on_start' as const, + stagingAreaAutoSwitch: 'switch_on_start', }); const slice = createSlice({ @@ -209,7 +209,7 @@ export const { export const canvasSettingsSliceConfig: SliceConfig = { slice, - zSchema: zCanvasSettingsState, + schema: zCanvasSettingsState, getInitialState, persistConfig: { migrate: (state) => zCanvasSettingsState.parse(state), diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 668ae31b47..bb77beb8bd 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -1720,7 +1720,7 @@ const canvasUndoableConfig: UndoableOptions = { export const canvasSliceConfig: SliceConfig = { slice, getInitialState: getInitialCanvasState, - zSchema: zCanvasState, + schema: zCanvasState, persistConfig: { migrate: (state) => zCanvasState.parse(state), }, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasStagingAreaSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasStagingAreaSlice.ts index e98fc75b87..694abcda1c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasStagingAreaSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasStagingAreaSlice.ts @@ -3,19 +3,25 @@ import { EMPTY_ARRAY } from 'app/store/constants'; import type { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import type { SliceConfig } from 'app/store/types'; +import { isPlainObject } from 'es-toolkit'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { useMemo } from 'react'; import { queueApi } from 'services/api/endpoints/queue'; +import { assert } from 'tsafe'; import z from 'zod'; const zCanvasStagingAreaState = z.object({ - _version: z.literal(1).default(1), - canvasSessionId: z.string().default(() => getPrefixedId('canvas')), - canvasDiscardedQueueItems: z.array(z.number().int()).default(() => []), + _version: z.literal(1), + canvasSessionId: z.string(), + canvasDiscardedQueueItems: z.array(z.number().int()), }); type CanvasStagingAreaState = z.infer; -const getInitialState = (): CanvasStagingAreaState => zCanvasStagingAreaState.parse({}); +const getInitialState = (): CanvasStagingAreaState => ({ + _version: 1, + canvasSessionId: getPrefixedId('canvas'), + canvasDiscardedQueueItems: [], +}); const slice = createSlice({ name: 'canvasSession', @@ -48,18 +54,17 @@ export const { canvasSessionReset, canvasQueueItemDiscarded } = slice.actions; export const canvasSessionSliceConfig: SliceConfig = { slice, - zSchema: zCanvasStagingAreaState, + schema: zCanvasStagingAreaState, getInitialState, persistConfig: { migrate: (state) => { - { - if (!('_version' in state)) { - state._version = 1; - state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas'); - } - - return zCanvasStagingAreaState.parse(state); + assert(isPlainObject(state)); + if (!('_version' in state)) { + state._version = 1; + state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas'); } + + return zCanvasStagingAreaState.parse(state); }, }, }; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/lorasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/lorasSlice.ts index 5c85199c77..c4fc565146 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/lorasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/lorasSlice.ts @@ -2,14 +2,16 @@ import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolki import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; import { paramsReset } from 'features/controlLayers/store/paramsSlice'; -import type { LoRA } from 'features/controlLayers/store/types'; +import { type LoRA, zLoRA } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { LoRAModelConfig } from 'services/api/types'; import { v4 as uuidv4 } from 'uuid'; +import z from 'zod'; -type LoRAsState = { - loras: LoRA[]; -}; +const zLoRAsState = z.object({ + loras: z.array(zLoRA), +}); +type LoRAsState = z.infer; const defaultLoRAConfig: Pick = { weight: 0.75, @@ -74,16 +76,12 @@ const slice = createSlice({ export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } = slice.actions; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrate = (state: any): any => { - return state; -}; - export const lorasSliceConfig: SliceConfig = { slice, + schema: zLoRAsState, getInitialState, persistConfig: { - migrate, + migrate: (state) => zLoRAsState.parse(state), }, }; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 893809b009..ac5e4c1529 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -403,7 +403,7 @@ export const { export const paramsSliceConfig: SliceConfig = { slice, - zSchema: zParamsState, + schema: zParamsState, getInitialState: getInitialParamsState, persistConfig: { migrate: (state) => zParamsState.parse(state), diff --git a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts index f52bc2a980..5876625085 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts @@ -266,17 +266,12 @@ export const { refImagesRecalled, } = slice.actions; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrate = (state: any): any => { - return state; -}; - export const refImagesSliceConfig: SliceConfig = { slice, - zSchema: zRefImagesState, + schema: zRefImagesState, getInitialState: getInitialRefImagesState, persistConfig: { - migrate, + migrate: (state) => zRefImagesState.parse(state), persistDenylist: ['selectedEntityId', 'isPanelOpen'], }, }; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 2607c396cb..3bed7b112e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -3,7 +3,6 @@ import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEnt import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers'; import type { ProgressImage } from 'features/nodes/types/common'; import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common'; -import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; import { zParameterCanvasCoherenceMode, zParameterCFGRescaleMultiplier, @@ -45,7 +44,7 @@ const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async } }); -const zImageWithDims = z +export const zImageWithDims = z .object({ image_name: z.string(), width: z.number().int().positive(), @@ -424,12 +423,13 @@ export const zCanvasEntityIdentifer = z.object({ }); export type CanvasEntityIdentifier = { id: string; type: T }; -export type LoRA = { - id: string; - isEnabled: boolean; - model: ParameterLoRAModel; - weight: number; -}; +export const zLoRA = z.object({ + id: z.string(), + isEnabled: z.boolean(), + model: zServerValidatedModelIdentifierField, + weight: z.number().gte(-1).lte(2), +}); +export type LoRA = z.infer; export type EphemeralProgressImage = { sessionId: string; image: ProgressImage }; @@ -574,11 +574,11 @@ export const zParamsState = z.object({ export type ParamsState = z.infer; export const getInitialParamsState = (): ParamsState => ({ maskBlur: 16, - maskBlurMethod: 'box' as const, - canvasCoherenceMode: 'Gaussian Blur' as const, + maskBlurMethod: 'box', + canvasCoherenceMode: 'Gaussian Blur', canvasCoherenceMinDenoise: 0, canvasCoherenceEdgeSize: 16, - infillMethod: 'lama' as const, + infillMethod: 'lama', infillTileSize: 32, infillPatchmatchDownscaleSize: 1, infillColorValue: { r: 0, g: 0, b: 0, a: 1 }, @@ -588,15 +588,15 @@ export const getInitialParamsState = (): ParamsState => ({ img2imgStrength: 0.75, optimizedDenoisingEnabled: true, iterations: 1, - scheduler: 'dpmpp_3m_k' as const, - upscaleScheduler: 'kdpm_2' as const, + scheduler: 'dpmpp_3m_k', + upscaleScheduler: 'kdpm_2', upscaleCfgScale: 2, seed: 0, shouldRandomizeSeed: true, steps: 30, model: null, vae: null, - vaePrecision: 'fp32' as const, + vaePrecision: 'fp32', fluxVAE: null, seamlessXAxis: false, seamlessYAxis: false, @@ -610,7 +610,7 @@ export const getInitialParamsState = (): ParamsState => ({ refinerModel: null, refinerSteps: 20, refinerCFGScale: 7.5, - refinerScheduler: 'euler' as const, + refinerScheduler: 'euler', refinerPositiveAestheticScore: 6, refinerNegativeAestheticScore: 2.5, refinerStart: 0.8, @@ -653,7 +653,7 @@ export const zCanvasState = z.object({ }); export type CanvasState = z.infer; export const getInitialCanvasState = (): CanvasState => ({ - _version: 3 as const, + _version: 3, selectedEntityIdentifier: null, bookmarkedEntityIdentifier: null, inpaintMasks: { isHidden: false, entities: [] }, @@ -663,9 +663,9 @@ export const getInitialCanvasState = (): CanvasState => ({ bbox: { rect: { x: 0, y: 0, width: 512, height: 512 }, aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG), - scaleMethod: 'auto' as const, + scaleMethod: 'auto', scaledSize: { width: 512, height: 512 }, - modelBase: 'sd-1' as const, + modelBase: 'sd-1', }, }); diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/store/dynamicPromptsSlice.ts b/invokeai/frontend/web/src/features/dynamicPrompts/store/dynamicPromptsSlice.ts index 6490c75a57..8e07126122 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/store/dynamicPromptsSlice.ts +++ b/invokeai/frontend/web/src/features/dynamicPrompts/store/dynamicPromptsSlice.ts @@ -3,6 +3,8 @@ import { createSelector, createSlice } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; import { buildZodTypeGuard } from 'common/util/zodUtils'; +import { isPlainObject } from 'es-toolkit'; +import { assert } from 'tsafe'; import { z } from 'zod'; const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']); @@ -19,7 +21,7 @@ const zDynamicPromptsState = z.object({ isLoading: z.boolean(), seedBehaviour: zSeedBehaviour, }); -type DynamicPromptsState = z.infer; +export type DynamicPromptsState = z.infer; const getInitialState = (): DynamicPromptsState => ({ _version: 1, @@ -69,10 +71,11 @@ export const { export const dynamicPromptsSliceConfig: SliceConfig = { slice, - zSchema: zDynamicPromptsState, + schema: zDynamicPromptsState, getInitialState, persistConfig: { migrate: (state) => { + assert(isPlainObject(state)); if (!('_version' in state)) { state._version = 1; } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/ImageMenuItemSendToUpscale.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/ImageMenuItemSendToUpscale.tsx index 9674b6c9c4..14b561c836 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/ImageMenuItemSendToUpscale.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/ImageMenuItemSendToUpscale.tsx @@ -1,5 +1,6 @@ import { MenuItem } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; +import { imageDTOToImageWithDims } from 'features/controlLayers/store/util'; import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext'; import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice'; import { toast } from 'features/toast/toast'; @@ -14,7 +15,7 @@ export const ImageMenuItemSendToUpscale = memo(() => { const imageDTO = useImageDTOContext(); const handleSendToCanvas = useCallback(() => { - dispatch(upscaleInitialImageChanged(imageDTO)); + dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO))); navigationApi.switchToTab('upscaling'); toast({ id: 'SENT_TO_CANVAS', diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index c788280bb2..c0794183a1 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -3,10 +3,19 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; +import { isPlainObject } from 'es-toolkit'; import { uniq } from 'es-toolkit/compat'; import type { BoardRecordOrderBy } from 'services/api/types'; +import { assert } from 'tsafe'; -import type { BoardId, ComparisonMode, GalleryState, GalleryView, OrderDir } from './types'; +import { + type BoardId, + type ComparisonMode, + type GalleryState, + type GalleryView, + type OrderDir, + zGalleryState, +} from './types'; const getInitialState = (): GalleryState => ({ selection: [], @@ -192,19 +201,18 @@ export const { export const selectGallerySlice = (state: RootState) => state.gallery; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrate = (state: any): any => { - if (!('_version' in state)) { - state._version = 1; - } - return state; -}; - export const gallerySliceConfig: SliceConfig = { slice, + schema: zGalleryState, getInitialState, persistConfig: { - migrate, + migrate: (state) => { + assert(isPlainObject(state)); + if (!('_version' in state)) { + state._version = 1; + } + return zGalleryState.parse(state); + }, persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'], }, }; diff --git a/invokeai/frontend/web/src/features/gallery/store/types.test.ts b/invokeai/frontend/web/src/features/gallery/store/types.test.ts new file mode 100644 index 0000000000..39a7a5602c --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/store/types.test.ts @@ -0,0 +1,13 @@ +import type { S } from 'services/api/types'; +import type { Equals } from 'tsafe'; +import { assert } from 'tsafe'; +import { describe, test } from 'vitest'; + +import type { BoardRecordOrderBy } from './types'; + +describe('Gallery Types', () => { + // Ensure zod types match OpenAPI types + test('BoardRecordOrderBy', () => { + assert>(); + }); +}); diff --git a/invokeai/frontend/web/src/features/gallery/store/types.ts b/invokeai/frontend/web/src/features/gallery/store/types.ts index b7a2ee4d11..959cd18301 100644 --- a/invokeai/frontend/web/src/features/gallery/store/types.ts +++ b/invokeai/frontend/web/src/features/gallery/store/types.ts @@ -1,31 +1,41 @@ -import type { BoardRecordOrderBy, ImageCategory } from 'services/api/types'; +import type { ImageCategory } from 'services/api/types'; +import z from 'zod'; + +const zGalleryView = z.enum(['images', 'assets']); +export type GalleryView = z.infer; +const zBoardId = z.union([z.literal('none'), z.intersection(z.string(), z.record(z.never(), z.never()))]); +export type BoardId = z.infer; +const zComparisonMode = z.enum(['slider', 'side-by-side', 'hover']); +export type ComparisonMode = z.infer; +const zComparisonFit = z.enum(['contain', 'fill']); +export type ComparisonFit = z.infer; +const zOrderDir = z.enum(['ASC', 'DESC']); +export type OrderDir = z.infer; +const zBoardRecordOrderBy = z.enum(['created_at', 'board_name']); +export type BoardRecordOrderBy = z.infer; export const IMAGE_CATEGORIES: ImageCategory[] = ['general']; export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other']; -export type GalleryView = 'images' | 'assets'; -export type BoardId = 'none' | (string & Record); -export type ComparisonMode = 'slider' | 'side-by-side' | 'hover'; -export type ComparisonFit = 'contain' | 'fill'; -export type OrderDir = 'ASC' | 'DESC'; +export const zGalleryState = z.object({ + selection: z.array(z.string()), + shouldAutoSwitch: z.boolean(), + autoAssignBoardOnClick: z.boolean(), + autoAddBoardId: zBoardId, + galleryImageMinimumWidth: z.number(), + selectedBoardId: zBoardId, + galleryView: zGalleryView, + boardSearchText: z.string(), + starredFirst: z.boolean(), + orderDir: zOrderDir, + searchTerm: z.string(), + alwaysShowImageSizeBadge: z.boolean(), + imageToCompare: z.string().nullable(), + comparisonMode: zComparisonMode, + comparisonFit: zComparisonFit, + shouldShowArchivedBoards: z.boolean(), + boardsListOrderBy: zBoardRecordOrderBy, + boardsListOrderDir: zOrderDir, +}); -export type GalleryState = { - selection: string[]; - shouldAutoSwitch: boolean; - autoAssignBoardOnClick: boolean; - autoAddBoardId: BoardId; - galleryImageMinimumWidth: number; - selectedBoardId: BoardId; - galleryView: GalleryView; - boardSearchText: string; - starredFirst: boolean; - orderDir: OrderDir; - searchTerm: string; - alwaysShowImageSizeBadge: boolean; - imageToCompare: string | null; - comparisonMode: ComparisonMode; - comparisonFit: ComparisonFit; - shouldShowArchivedBoards: boolean; - boardsListOrderBy: BoardRecordOrderBy; - boardsListOrderDir: OrderDir; -}; +export type GalleryState = z.infer; diff --git a/invokeai/frontend/web/src/features/imageActions/actions.ts b/invokeai/frontend/web/src/features/imageActions/actions.ts index 885b1105c6..c53d6dab23 100644 --- a/invokeai/frontend/web/src/features/imageActions/actions.ts +++ b/invokeai/frontend/web/src/features/imageActions/actions.ts @@ -58,7 +58,7 @@ export const setRegionalGuidanceReferenceImage = (arg: { export const setUpscaleInitialImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => { const { imageDTO, dispatch } = arg; - dispatch(upscaleInitialImageChanged(imageDTO)); + dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO))); }; export const setNodeImageFieldImage = (arg: { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts index 15d1ea5bf8..7f20f21c22 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts @@ -2,19 +2,25 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; -import type { ModelType } from 'services/api/types'; +import { isPlainObject } from 'es-toolkit'; +import { zModelType } from 'features/nodes/types/common'; +import { assert } from 'tsafe'; +import z from 'zod'; -export type FilterableModelType = Exclude | 'refiner'; +const zFilterableModelType = zModelType.exclude(['onnx']).or(z.literal('refiner')); +export type FilterableModelType = z.infer; -type ModelManagerState = { - _version: 1; - selectedModelKey: string | null; - selectedModelMode: 'edit' | 'view'; - searchTerm: string; - filteredModelType: FilterableModelType | null; - scanPath: string | undefined; - shouldInstallInPlace: boolean; -}; +const zModelManagerState = z.object({ + _version: z.literal(1), + selectedModelKey: z.string().nullable(), + selectedModelMode: z.enum(['edit', 'view']), + searchTerm: z.string(), + filteredModelType: zFilterableModelType.nullable(), + scanPath: z.string().optional(), + shouldInstallInPlace: z.boolean(), +}); + +type ModelManagerState = z.infer; const getInitialState = (): ModelManagerState => ({ _version: 1, @@ -61,19 +67,18 @@ export const { shouldInstallInPlaceChanged, } = slice.actions; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrate = (state: any): any => { - if (!('_version' in state)) { - state._version = 1; - } - return state; -}; - export const modelManagerSliceConfig: SliceConfig = { slice, + schema: zModelManagerState, getInitialState, persistConfig: { - migrate, + migrate: (state) => { + assert(isPlainObject(state)); + if (!('_version' in state)) { + state._version = 1; + } + return zModelManagerState.parse(state); + }, persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'], }, }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 1d1ad34f95..c6f29095e4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -14,7 +14,13 @@ import type { ReactFlowProps, ReactFlowState, } from '@xyflow/react'; -import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from '@xyflow/react'; +import { + Background, + ReactFlow, + SelectionMode, + useStore as useReactFlowStore, + useUpdateNodeInternals, +} from '@xyflow/react'; import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus'; import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish'; @@ -256,7 +262,7 @@ export const Flow = memo(() => { style={flowStyles} onPaneClick={handlePaneClick} deleteKeyCode={null} - selectionMode={selectionMode} + selectionMode={selectionMode === 'full' ? SelectionMode.Full : SelectionMode.Partial} elevateEdgesOnSelect nodeDragThreshold={1} noDragClassName={NO_DRAG_CLASS} diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 4d06d1c205..592af29d51 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -13,12 +13,13 @@ import type { import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from '@xyflow/react'; import type { SliceConfig } from 'app/store/types'; import { deepClone } from 'common/util/deepClone'; +import { isPlainObject } from 'es-toolkit'; import { addElement, removeElement, reparentElement, } from 'features/nodes/components/sidePanel/builder/form-manipulation'; -import type { NodesState } from 'features/nodes/store/types'; +import { type NodesState, zNodesState } from 'features/nodes/store/types'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import type { BoardFieldValue, @@ -127,6 +128,7 @@ import { import { atom, computed } from 'nanostores'; import type { MouseEvent } from 'react'; import type { UndoableOptions } from 'redux-undo'; +import { assert } from 'tsafe'; import type { z } from 'zod'; import type { PendingConnection, Templates } from './types'; @@ -760,14 +762,6 @@ export const { redo, } = slice.actions; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrate = (state: any): any => { - if (!('_version' in state)) { - state._version = 1; - } - return state; -}; - export const $cursorPos = atom(null); export const $templates = atom({}); export const $hasTemplates = computed($templates, (templates) => Object.keys(templates).length > 0); @@ -938,9 +932,16 @@ const reduxUndoOptions: UndoableOptions = { export const nodesSliceConfig: SliceConfig = { slice, + schema: zNodesState, getInitialState, persistConfig: { - migrate, + migrate: (state) => { + assert(isPlainObject(state)); + if (!('_version' in state)) { + state._version = 1; + } + return zNodesState.parse(state); + }, }, undoableConfig: { reduxUndoOptions, diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 8728b43e4c..a3391d7dae 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -1,7 +1,8 @@ import type { HandleType } from '@xyflow/react'; -import type { FieldInputTemplate, FieldOutputTemplate, StatefulFieldValue } from 'features/nodes/types/field'; -import type { AnyEdge, AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation'; -import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import { type FieldInputTemplate, type FieldOutputTemplate, zStatefulFieldValue } from 'features/nodes/types/field'; +import { type InvocationTemplate, type NodeExecutionState, zAnyEdge, zAnyNode } from 'features/nodes/types/invocation'; +import { zWorkflowV3 } from 'features/nodes/types/workflow'; +import z from 'zod'; export type Templates = Record; export type NodeExecutionStates = Record; @@ -13,11 +14,13 @@ export type PendingConnection = { fieldTemplate: FieldInputTemplate | FieldOutputTemplate; }; -export type WorkflowMode = 'edit' | 'view'; - -export type NodesState = { - _version: 1; - nodes: AnyNode[]; - edges: AnyEdge[]; - formFieldInitialValues: Record; -} & Omit; +export const zWorkflowMode = z.enum(['edit', 'view']); +export type WorkflowMode = z.infer; +export const zNodesState = z.object({ + _version: z.literal(1), + nodes: z.array(zAnyNode), + edges: z.array(zAnyEdge), + formFieldInitialValues: z.record(z.string(), zStatefulFieldValue), + ...zWorkflowV3.omit({ nodes: true, edges: true, is_published: true }).shape, +}); +export type NodesState = z.infer; diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts index c17e5f14d3..a1d8b89464 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts @@ -2,21 +2,29 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; -import type { WorkflowMode } from 'features/nodes/store/types'; +import { type WorkflowMode, zWorkflowMode } from 'features/nodes/store/types'; import type { WorkflowCategory } from 'features/nodes/types/workflow'; import { atom, computed } from 'nanostores'; -import type { SQLiteDirection, WorkflowRecordOrderBy } from 'services/api/types'; +import { + type SQLiteDirection, + type WorkflowRecordOrderBy, + zSQLiteDirection, + zWorkflowRecordOrderBy, +} from 'services/api/types'; +import z from 'zod'; -export type WorkflowLibraryView = 'recent' | 'yours' | 'private' | 'shared' | 'defaults' | 'published'; +const zWorkflowLibraryView = z.enum(['recent', 'yours', 'private', 'shared', 'defaults', 'published']); +export type WorkflowLibraryView = z.infer; -type WorkflowLibraryState = { - mode: WorkflowMode; - view: WorkflowLibraryView; - orderBy: WorkflowRecordOrderBy; - direction: SQLiteDirection; - searchTerm: string; - selectedTags: string[]; -}; +const zWorkflowLibraryState = z.object({ + mode: zWorkflowMode, + view: zWorkflowLibraryView, + orderBy: zWorkflowRecordOrderBy, + direction: zSQLiteDirection, + searchTerm: z.string(), + selectedTags: z.array(z.string()), +}); +type WorkflowLibraryState = z.infer; const getInitialState = (): WorkflowLibraryState => ({ mode: 'view', @@ -76,14 +84,12 @@ export const { workflowLibraryViewChanged, } = slice.actions; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrate = (state: any): any => state; - export const workflowLibrarySliceConfig: SliceConfig = { slice, + schema: zWorkflowLibraryState, getInitialState, persistConfig: { - migrate, + migrate: (state) => zWorkflowLibraryState.parse(state), }, }; diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts index 8d9cd4d837..85b803acd4 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts @@ -1,9 +1,10 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit'; -import { SelectionMode } from '@xyflow/react'; import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; +import { isPlainObject } from 'es-toolkit'; import type { Selector } from 'react-redux'; +import { assert } from 'tsafe'; import z from 'zod'; export const zLayeringStrategy = z.enum(['network-simplex', 'longest-path']); @@ -12,23 +13,26 @@ export const zLayoutDirection = z.enum(['TB', 'LR']); type LayoutDirection = z.infer; export const zNodeAlignment = z.enum(['UL', 'UR', 'DL', 'DR']); type NodeAlignment = z.infer; +const zSelectionMode = z.enum(['partial', 'full']); -export type WorkflowSettingsState = { - _version: 1; - shouldShowMinimapPanel: boolean; - layeringStrategy: LayeringStrategy; - nodeSpacing: number; - layerSpacing: number; - layoutDirection: LayoutDirection; - shouldValidateGraph: boolean; - shouldAnimateEdges: boolean; - nodeAlignment: NodeAlignment; - nodeOpacity: number; - shouldSnapToGrid: boolean; - shouldColorEdges: boolean; - shouldShowEdgeLabels: boolean; - selectionMode: SelectionMode; -}; +const zWorkflowSettingsState = z.object({ + _version: z.literal(1), + shouldShowMinimapPanel: z.boolean(), + layeringStrategy: zLayeringStrategy, + nodeSpacing: z.number(), + layerSpacing: z.number(), + layoutDirection: zLayoutDirection, + shouldValidateGraph: z.boolean(), + shouldAnimateEdges: z.boolean(), + nodeAlignment: zNodeAlignment, + nodeOpacity: z.number(), + shouldSnapToGrid: z.boolean(), + shouldColorEdges: z.boolean(), + shouldShowEdgeLabels: z.boolean(), + selectionMode: zSelectionMode, +}); + +export type WorkflowSettingsState = z.infer; const getInitialState = (): WorkflowSettingsState => ({ _version: 1, @@ -44,7 +48,7 @@ const getInitialState = (): WorkflowSettingsState => ({ shouldColorEdges: true, shouldShowEdgeLabels: false, nodeOpacity: 1, - selectionMode: SelectionMode.Partial, + selectionMode: 'partial', }); const slice = createSlice({ @@ -88,7 +92,7 @@ const slice = createSlice({ state.nodeAlignment = action.payload; }, selectionModeChanged: (state, action: PayloadAction) => { - state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial; + state.selectionMode = action.payload ? 'full' : 'partial'; }, }, }); @@ -109,19 +113,18 @@ export const { selectionModeChanged, } = slice.actions; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrate = (state: any): any => { - if (!('_version' in state)) { - state._version = 1; - } - return state; -}; - export const workflowSettingsSliceConfig: SliceConfig = { slice, + schema: zWorkflowSettingsState, getInitialState, persistConfig: { - migrate, + migrate: (state) => { + assert(isPlainObject(state)); + if (!('_version' in state)) { + state._version = 1; + } + return zWorkflowSettingsState.parse(state); + }, }, }; diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 1c9b4ec8ee..ea6ad790a0 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -92,7 +92,7 @@ export const zMainModelBase = z.enum([ ]); type MainModelBase = z.infer; export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success; -const zModelType = z.enum([ +export const zModelType = z.enum([ 'main', 'vae', 'lora', diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index c2435ecfb8..96a52a2377 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -43,7 +43,7 @@ export const zNotesNodeData = z.object({ isOpen: z.boolean(), notes: z.string(), }); -const _zCurrentImageNodeData = z.object({ +const zCurrentImageNodeData = z.object({ id: z.string().trim().min(1), type: z.literal('current_image'), label: z.string(), @@ -52,12 +52,35 @@ const _zCurrentImageNodeData = z.object({ export type NotesNodeData = z.infer; export type InvocationNodeData = z.infer; -type CurrentImageNodeData = z.infer; +type CurrentImageNodeData = z.infer; -export type InvocationNode = Node; -export type NotesNode = Node; -export type CurrentImageNode = Node; -export type AnyNode = InvocationNode | NotesNode | CurrentImageNode; +const zInvocationNodeValidationSchema = z.looseObject({ + type: z.literal('invocation'), + data: zInvocationNodeData, +}); +const zInvocationNode = z.custom>( + (val) => zInvocationNodeValidationSchema.safeParse(val).success +); +export type InvocationNode = z.infer; + +const zNotesNodeValidationSchema = z.looseObject({ + type: z.literal('notes'), + data: zNotesNodeData, +}); +const zNotesNode = z.custom>((val) => zNotesNodeValidationSchema.safeParse(val).success); +export type NotesNode = z.infer; + +const zCurrentImageNodeValidationSchema = z.looseObject({ + type: z.literal('current_image'), + data: zCurrentImageNodeData, +}); +const zCurrentImageNode = z.custom>( + (val) => zCurrentImageNodeValidationSchema.safeParse(val).success +); +export type CurrentImageNode = z.infer; + +export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]); +export type AnyNode = z.infer; export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode => Boolean(node && node.type === 'invocation'); @@ -83,13 +106,29 @@ export type NodeExecutionState = z.infer; // #endregion // #region Edges -const _zInvocationNodeEdgeCollapsedData = z.object({ +const zDefaultInvocationNodeEdgeValidationSchema = z.looseObject({ + type: z.literal('default'), +}); +const zDefaultInvocationNodeEdge = z.custom, 'default'>>( + (val) => zDefaultInvocationNodeEdgeValidationSchema.safeParse(val).success +); +export type DefaultInvocationNodeEdge = z.infer; + +const zInvocationNodeEdgeCollapsedData = z.object({ count: z.number().int().min(1), }); -type InvocationNodeEdgeCollapsedData = z.infer; -export type DefaultInvocationNodeEdge = Edge, 'default'>; -export type CollapsedInvocationNodeEdge = Edge; -export type AnyEdge = DefaultInvocationNodeEdge | CollapsedInvocationNodeEdge; +const zInvocationNodeEdgeCollapsedValidationSchema = z.looseObject({ + type: z.literal('default'), + data: zInvocationNodeEdgeCollapsedData, +}); +type InvocationNodeEdgeCollapsedData = z.infer; + +const zCollapsedInvocationNodeEdge = z.custom>( + (val) => zInvocationNodeEdgeCollapsedValidationSchema.safeParse(val).success +); +export type CollapsedInvocationNodeEdge = z.infer; +export const zAnyEdge = z.union([zDefaultInvocationNodeEdge, zCollapsedInvocationNodeEdge]); +export type AnyEdge = z.infer; // #endregion export const isBatchNodeType = (type: string) => diff --git a/invokeai/frontend/web/src/features/parameters/components/Upscale/ParamTileControlNetModel.tsx b/invokeai/frontend/web/src/features/parameters/components/Upscale/ParamTileControlNetModel.tsx index 048b666b37..07a5395e03 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Upscale/ParamTileControlNetModel.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Upscale/ParamTileControlNetModel.tsx @@ -1,4 +1,5 @@ import { FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { selectBase } from 'features/controlLayers/store/paramsSlice'; @@ -6,13 +7,35 @@ import { ModelPicker } from 'features/parameters/components/ModelPicker'; import { selectTileControlNetModel, tileControlnetModelChanged } from 'features/parameters/store/upscaleSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; import { useControlNetModels } from 'services/api/hooks/modelsByType'; -import type { ControlNetModelConfig } from 'services/api/types'; +import { type ControlNetModelConfig, isControlNetModelConfig } from 'services/api/types'; + +const selectTileControlNetModelConfig = createSelector( + selectModelConfigsQuery, + selectTileControlNetModel, + (modelConfigs, modelIdentifierField) => { + if (!modelConfigs.data) { + return null; + } + if (!modelIdentifierField) { + return null; + } + const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, modelIdentifierField.key); + if (!modelConfig) { + return null; + } + if (!isControlNetModelConfig(modelConfig)) { + return null; + } + return modelConfig; + } +); const ParamTileControlNetModel = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const tileControlNetModel = useAppSelector(selectTileControlNetModel); + const tileControlNetModel = useAppSelector(selectTileControlNetModelConfig); const currentBaseModel = useAppSelector(selectBase); const [modelConfigs, { isLoading }] = useControlNetModels(); diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useIsTooLargeToUpscale.ts b/invokeai/frontend/web/src/features/parameters/hooks/useIsTooLargeToUpscale.ts index 8bc8859af5..7428d50c9b 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useIsTooLargeToUpscale.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useIsTooLargeToUpscale.ts @@ -1,21 +1,21 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; +import type { ImageWithDims } from 'features/controlLayers/store/types'; import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice'; import { selectConfigSlice } from 'features/system/store/configSlice'; import { useMemo } from 'react'; -import type { ImageDTO } from 'services/api/types'; -const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) => +const createIsTooLargeToUpscaleSelector = (imageWithDims?: ImageWithDims | null) => createSelector(selectUpscaleSlice, selectConfigSlice, (upscale, config) => { const { upscaleModel, scale } = upscale; const { maxUpscaleDimension } = config; - if (!maxUpscaleDimension || !upscaleModel || !imageDTO) { + if (!maxUpscaleDimension || !upscaleModel || !imageWithDims) { // When these are missing, another warning will be shown return false; } - const { width, height } = imageDTO; + const { width, height } = imageWithDims; const maxPixels = maxUpscaleDimension ** 2; const upscaledPixels = width * scale * height * scale; @@ -23,7 +23,7 @@ const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) => return upscaledPixels > maxPixels; }); -export const useIsTooLargeToUpscale = (imageDTO?: ImageDTO | null) => { - const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageDTO), [imageDTO]); +export const useIsTooLargeToUpscale = (imageWithDims?: ImageWithDims | null) => { + const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageWithDims), [imageWithDims]); return useAppSelector(selectIsTooLargeToUpscale); }; diff --git a/invokeai/frontend/web/src/features/parameters/store/upscaleSlice.ts b/invokeai/frontend/web/src/features/parameters/store/upscaleSlice.ts index 8bf3ad58d5..0bb47b6f2e 100644 --- a/invokeai/frontend/web/src/features/parameters/store/upscaleSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/upscaleSlice.ts @@ -2,24 +2,32 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; +import { isPlainObject } from 'es-toolkit'; +import type { ImageWithDims } from 'features/controlLayers/store/types'; +import { zImageWithDims } from 'features/controlLayers/store/types'; +import { zModelIdentifierField } from 'features/nodes/types/common'; import type { ParameterSpandrelImageToImageModel } from 'features/parameters/types/parameterSchemas'; -import type { ControlNetModelConfig, ImageDTO } from 'services/api/types'; +import type { ControlNetModelConfig } from 'services/api/types'; +import { assert } from 'tsafe'; +import z from 'zod'; -export interface UpscaleState { - _version: 1; - upscaleModel: ParameterSpandrelImageToImageModel | null; - upscaleInitialImage: ImageDTO | null; - structure: number; - creativity: number; - tileControlnetModel: ControlNetModelConfig | null; - scale: number; - postProcessingModel: ParameterSpandrelImageToImageModel | null; - tileSize: number; - tileOverlap: number; -} +const zUpscaleState = z.object({ + _version: z.literal(2), + upscaleModel: zModelIdentifierField.nullable(), + upscaleInitialImage: zImageWithDims.nullable(), + structure: z.number(), + creativity: z.number(), + tileControlnetModel: zModelIdentifierField.nullable(), + scale: z.number(), + postProcessingModel: zModelIdentifierField.nullable(), + tileSize: z.number(), + tileOverlap: z.number(), +}); + +export type UpscaleState = z.infer; const getInitialState = (): UpscaleState => ({ - _version: 1, + _version: 2, upscaleModel: null, upscaleInitialImage: null, structure: 0, @@ -38,7 +46,7 @@ const slice = createSlice({ upscaleModelChanged: (state, action: PayloadAction) => { state.upscaleModel = action.payload; }, - upscaleInitialImageChanged: (state, action: PayloadAction) => { + upscaleInitialImageChanged: (state, action: PayloadAction) => { state.upscaleInitialImage = action.payload; }, structureChanged: (state, action: PayloadAction) => { @@ -77,19 +85,30 @@ export const { tileOverlapChanged, } = slice.actions; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrate = (state: any): any => { - if (!('_version' in state)) { - state._version = 1; - } - return state; -}; - export const upscaleSliceConfig: SliceConfig = { slice, + schema: zUpscaleState, getInitialState, persistConfig: { - migrate, + migrate: (state) => { + assert(isPlainObject(state)); + if (!('_version' in state)) { + state._version = 1; + } + if (state._version === 1) { + state._version = 2; + // Migrate from v1 to v2: upscaleInitialImage was an ImageDTO, now it's an ImageWithDims + if (state.upscaleInitialImage) { + const { image_name, width, height } = state.upscaleInitialImage; + state.upscaleInitialImage = { + image_name, + width, + height, + }; + } + } + return zUpscaleState.parse(state); + }, }, }; diff --git a/invokeai/frontend/web/src/features/queue/store/queueSlice.ts b/invokeai/frontend/web/src/features/queue/store/queueSlice.ts index a80ee7da1b..b6789400f8 100644 --- a/invokeai/frontend/web/src/features/queue/store/queueSlice.ts +++ b/invokeai/frontend/web/src/features/queue/store/queueSlice.ts @@ -2,13 +2,15 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; +import z from 'zod'; -interface QueueState { - listCursor: number | undefined; - listPriority: number | undefined; - selectedQueueItem: string | undefined; - resumeProcessorOnEnqueue: boolean; -} +const zQueueState = z.object({ + listCursor: z.number().optional(), + listPriority: z.number().optional(), + selectedQueueItem: z.string().optional(), + resumeProcessorOnEnqueue: z.boolean(), +}); +type QueueState = z.infer; const getInitialState = (): QueueState => ({ listCursor: undefined, @@ -38,6 +40,7 @@ export const { listCursorChanged, listPriorityChanged, listParamsReset } = slice export const queueSliceConfig: SliceConfig = { slice, + schema: zQueueState, getInitialState, }; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleInitialImage.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleInitialImage.tsx index 117e1e4159..7e1f1a433e 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleInitialImage.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleInitialImage.tsx @@ -1,6 +1,7 @@ import { Flex, Text } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { UploadImageIconButton } from 'common/hooks/useImageUploadButton'; +import { imageDTOToImageWithDims } from 'features/controlLayers/store/util'; import type { SetUpscaleInitialImageDndTargetData } from 'features/dnd/dnd'; import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd'; import { DndDropTarget } from 'features/dnd/DndDropTarget'; @@ -10,11 +11,13 @@ import { selectUpscaleInitialImage, upscaleInitialImageChanged } from 'features/ import { t } from 'i18next'; import { useCallback, useMemo } from 'react'; import { PiArrowCounterClockwiseBold } from 'react-icons/pi'; +import { useImageDTO } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; export const UpscaleInitialImage = () => { const dispatch = useAppDispatch(); - const imageDTO = useAppSelector(selectUpscaleInitialImage); + const upscaleInitialImage = useAppSelector(selectUpscaleInitialImage); + const imageDTO = useImageDTO(upscaleInitialImage?.image_name); const dndTargetData = useMemo( () => setUpscaleInitialImageDndTarget.getData(), [] @@ -26,7 +29,7 @@ export const UpscaleInitialImage = () => { const onUpload = useCallback( (imageDTO: ImageDTO) => { - dispatch(upscaleInitialImageChanged(imageDTO)); + dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO))); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/stylePresets/store/stylePresetSlice.ts b/invokeai/frontend/web/src/features/stylePresets/store/stylePresetSlice.ts index 0b3e3f62bc..6434fc614d 100644 --- a/invokeai/frontend/web/src/features/stylePresets/store/stylePresetSlice.ts +++ b/invokeai/frontend/web/src/features/stylePresets/store/stylePresetSlice.ts @@ -2,11 +2,21 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; +import { isPlainObject } from 'es-toolkit'; import { paramsReset } from 'features/controlLayers/store/paramsSlice'; import { atom } from 'nanostores'; import { stylePresetsApi } from 'services/api/endpoints/stylePresets'; +import { assert } from 'tsafe'; +import z from 'zod'; -import type { StylePresetState } from './types'; +const zStylePresetState = z.object({ + activeStylePresetId: z.string().nullable(), + searchTerm: z.string(), + viewMode: z.boolean(), + showPromptPreviews: z.boolean(), +}); + +type StylePresetState = z.infer; const getInitialState = (): StylePresetState => ({ activeStylePresetId: null, @@ -60,19 +70,18 @@ const slice = createSlice({ export const { activeStylePresetIdChanged, searchTermChanged, viewModeChanged, showPromptPreviewsChanged } = slice.actions; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrate = (state: any): any => { - if (!('_version' in state)) { - state._version = 1; - } - return state; -}; - export const stylePresetSliceConfig: SliceConfig = { slice, + schema: zStylePresetState, getInitialState, persistConfig: { - migrate, + migrate: (state) => { + assert(isPlainObject(state)); + if (!('_version' in state)) { + state._version = 1; + } + return zStylePresetState.parse(state); + }, }, }; diff --git a/invokeai/frontend/web/src/features/stylePresets/store/types.ts b/invokeai/frontend/web/src/features/stylePresets/store/types.ts deleted file mode 100644 index d156808384..0000000000 --- a/invokeai/frontend/web/src/features/stylePresets/store/types.ts +++ /dev/null @@ -1,6 +0,0 @@ -export type StylePresetState = { - activeStylePresetId: string | null; - searchTerm: string; - viewMode: boolean; - showPromptPreviews: boolean; -}; diff --git a/invokeai/frontend/web/src/features/system/store/configSlice.ts b/invokeai/frontend/web/src/features/system/store/configSlice.ts index e47adcaed4..5123107be8 100644 --- a/invokeai/frontend/web/src/features/system/store/configSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/configSlice.ts @@ -32,7 +32,7 @@ export const { configChanged } = slice.actions; export const configSliceConfig: SliceConfig = { slice, - zSchema: zConfigState, + schema: zConfigState, getInitialState, }; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 6bb615a6cc..1cc22c8dea 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -5,9 +5,11 @@ import { zLogNamespace } from 'app/logging/logger'; import { EMPTY_ARRAY } from 'app/store/constants'; import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; +import { isPlainObject } from 'es-toolkit'; import { uniq } from 'es-toolkit/compat'; +import { assert } from 'tsafe'; -import type { Language, SystemState } from './types'; +import { type Language, type SystemState, zSystemState } from './types'; const getInitialState = (): SystemState => ({ _version: 2, @@ -92,23 +94,22 @@ export const { setShouldHighlightFocusedRegions, } = slice.actions; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrate = (state: any): any => { - if (!('_version' in state)) { - state._version = 1; - } - if (state._version === 1) { - state.language = (state as SystemState).language.replace('_', '-'); - state._version = 2; - } - return state; -}; - export const systemSliceConfig: SliceConfig = { slice, + schema: zSystemState, getInitialState, persistConfig: { - migrate, + migrate: (state) => { + assert(isPlainObject(state)); + if (!('_version' in state)) { + state._version = 1; + } + if (state._version === 1) { + state.language = (state as SystemState).language.replace('_', '-'); + state._version = 2; + } + return zSystemState.parse(state); + }, }, }; diff --git a/invokeai/frontend/web/src/features/system/store/types.ts b/invokeai/frontend/web/src/features/system/store/types.ts index d2d9e456e1..3eaf8628c6 100644 --- a/invokeai/frontend/web/src/features/system/store/types.ts +++ b/invokeai/frontend/web/src/features/system/store/types.ts @@ -1,4 +1,4 @@ -import type { LogLevel, LogNamespace } from 'app/logging/logger'; +import { zLogLevel, zLogNamespace } from 'app/logging/logger'; import { z } from 'zod'; const zLanguage = z.enum([ @@ -29,19 +29,20 @@ const zLanguage = z.enum([ export type Language = z.infer; export const isLanguage = (v: unknown): v is Language => zLanguage.safeParse(v).success; -export interface SystemState { - _version: 2; - shouldConfirmOnDelete: boolean; - shouldAntialiasProgressImage: boolean; - shouldConfirmOnNewSession: boolean; - language: Language; - shouldUseNSFWChecker: boolean; - shouldUseWatermarker: boolean; - shouldEnableInformationalPopovers: boolean; - shouldEnableModelDescriptions: boolean; - logIsEnabled: boolean; - logLevel: LogLevel; - logNamespaces: LogNamespace[]; - shouldShowInvocationProgressDetail: boolean; - shouldHighlightFocusedRegions: boolean; -} +export const zSystemState = z.object({ + _version: z.literal(2), + shouldConfirmOnDelete: z.boolean(), + shouldAntialiasProgressImage: z.boolean(), + shouldConfirmOnNewSession: z.boolean(), + language: zLanguage, + shouldUseNSFWChecker: z.boolean(), + shouldUseWatermarker: z.boolean(), + shouldEnableInformationalPopovers: z.boolean(), + shouldEnableModelDescriptions: z.boolean(), + logIsEnabled: z.boolean(), + logLevel: zLogLevel, + logNamespaces: z.array(zLogNamespace), + shouldShowInvocationProgressDetail: z.boolean(), + shouldHighlightFocusedRegions: z.boolean(), +}); +export type SystemState = z.infer; diff --git a/invokeai/frontend/web/src/features/ui/layouts/UpscalingLaunchpadPanel.tsx b/invokeai/frontend/web/src/features/ui/layouts/UpscalingLaunchpadPanel.tsx index 496663064b..747b81a618 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/UpscalingLaunchpadPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/UpscalingLaunchpadPanel.tsx @@ -1,6 +1,7 @@ import { Box, Button, ButtonGroup, Flex, Grid, Heading, Icon, Text } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; +import { imageDTOToImageWithDims } from 'features/controlLayers/store/util'; import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd'; import { DndDropTarget } from 'features/dnd/DndDropTarget'; import { @@ -37,7 +38,7 @@ export const UpscalingLaunchpadPanel = memo(() => { const onUpload = useCallback( (imageDTO: ImageDTO) => { - dispatch(upscaleInitialImageChanged(imageDTO)); + dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO))); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 35ead003ac..f0e88cc8fe 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -2,6 +2,8 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types'; +import { isPlainObject } from 'es-toolkit'; +import { assert } from 'tsafe'; import { getInitialUIState, type UIState, zUIState } from './uiTypes'; @@ -87,10 +89,11 @@ export const selectUiSlice = (state: RootState) => state.ui; export const uiSliceConfig: SliceConfig = { slice, - zSchema: zUIState, + schema: zUIState, getInitialState: getInitialUIState, persistConfig: { migrate: (state) => { + assert(isPlainObject(state)); if (!('_version' in state)) { state._version = 1; } diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 725b0ebfe8..98f0d98034 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -1,6 +1,9 @@ import type { Dimensions } from 'features/controlLayers/store/types'; import type { components, paths } from 'services/api/schema'; +import type { Equals } from 'tsafe'; +import { assert } from 'tsafe'; import type { JsonObject, SetRequired } from 'type-fest'; +import z from 'zod'; export type S = components['schemas']; @@ -33,10 +36,36 @@ export type InvocationJSONSchemaExtra = S['UIConfigBase']; export type AppVersion = S['AppVersion']; export type AppConfig = S['AppConfig']; +const zResourceOrigin = z.enum(['internal', 'external']); +type ResourceOrigin = z.infer; +assert>(); +const zImageCategory = z.enum(['general', 'mask', 'control', 'user', 'other']); +export type ImageCategory = z.infer; +assert>(); + // Images -export type ImageDTO = S['ImageDTO']; +const _zImageDTO = z.object({ + image_name: z.string(), + image_url: z.string(), + thumbnail_url: z.string(), + image_origin: zResourceOrigin, + image_category: zImageCategory, + width: z.number().int().gt(0), + height: z.number().int().gt(0), + created_at: z.string(), + updated_at: z.string(), + deleted_at: z.string().nullish(), + is_intermediate: z.boolean(), + session_id: z.string().nullish(), + node_id: z.string().nullish(), + starred: z.boolean(), + has_workflow: z.boolean(), + board_id: z.string().nullish(), +}); +export type ImageDTO = z.infer; +assert>(); + export type BoardDTO = S['BoardDTO']; -export type ImageCategory = S['ImageCategory']; export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDTO_']; // Models @@ -298,8 +327,13 @@ export type ModelInstallStatus = S['InstallStatus']; export type Graph = S['Graph']; export type NonNullableGraph = SetRequired; export type Batch = S['Batch']; -export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy']; -export type SQLiteDirection = S['SQLiteDirection']; +export const zWorkflowRecordOrderBy = z.enum(['name', 'created_at', 'updated_at', 'opened_at']); +export type WorkflowRecordOrderBy = z.infer; +assert>(); + +export const zSQLiteDirection = z.enum(['ASC', 'DESC']); +export type SQLiteDirection = z.infer; +assert>(); export type WorkflowRecordListItemWithThumbnailDTO = S['WorkflowRecordListItemWithThumbnailDTO']; type KeysOfUnion = T extends T ? keyof T : never;