refactor(ui): use zod for all redux state

This commit is contained in:
psychedelicious
2025-07-25 13:13:28 +10:00
parent 6962536b4a
commit aed9b1013e
39 changed files with 488 additions and 299 deletions

View File

@@ -197,6 +197,10 @@ export default [
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
{
name: 'zod/v3',
message: 'Import from zod instead.',
},
],
},
],

View File

@@ -128,28 +128,26 @@ const unserialize: UnserializeFunction = (data, key) => {
try {
const initialState = getInitialState();
const parsed = JSON.parse(data);
// strip out old keys
const stripped = pick(deepClone(parsed), keys(initialState));
// run (additive) migrations
const migrated = persistConfig.migrate(stripped);
const stripped = pick(deepClone(data), keys(initialState));
/*
* Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep,
* but that merges arrays by index and partial objects by key. Using an identity function as the customizer results
* in behaviour like defaultsDeep, but doesn't overwrite any values that are not undefined in the migrated state.
*/
const transformed = mergeWith(migrated, initialState, (objVal) => objVal);
const unPersistDenylisted = mergeWith(stripped, initialState, (objVal) => objVal);
// run (additive) migrations
const migrated = persistConfig.migrate(unPersistDenylisted);
log.debug(
{
persistedData: parsed,
rehydratedData: transformed as JsonObject,
diff: diff(parsed, transformed) as JsonObject,
persistedData: data as JsonObject,
rehydratedData: migrated as JsonObject,
diff: diff(data, migrated) as JsonObject,
},
`Rehydrated slice "${key}"`
);
state = transformed;
state = migrated;
} catch (err) {
log.warn(
{ error: serializeError(err as Error) },

View File

@@ -1,4 +1,3 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import type { Slice } from '@reduxjs/toolkit';
import type { UndoableOptions } from 'redux-undo';
import type { ZodType } from 'zod';
@@ -13,7 +12,7 @@ export type SliceConfig<T extends Slice> = {
/**
* The zod schema for the slice.
*/
zSchema: ZodType<StateFromSlice<T>>;
schema: ZodType<StateFromSlice<T>>;
/**
* A function that returns the initial state of the slice.
*/
@@ -23,11 +22,13 @@ export type SliceConfig<T extends Slice> = {
*/
persistConfig?: {
/**
* Migrate the state to the current version during rehydration.
* Migrate the state to the current version during rehydration. This method should throw an error if the migration
* fails.
*
* @param state The rehydrated state.
* @returns A correctly-shaped state.
*/
migrate: (state: any) => StateFromSlice<T>;
migrate: (state: unknown) => StateFromSlice<T>;
/**
* Keys to omit from the persisted state.
*/

View File

@@ -58,7 +58,6 @@ const zNumericalParameterConfig = z.object({
fineStep: z.number().default(8),
coarseStep: z.number().default(64),
});
export type NumericalParameterConfig = z.infer<typeof zNumericalParameterConfig>;
/**
* Configuration options for the InvokeAI UI.

View File

@@ -35,6 +35,6 @@ export const selectChangeBoardModalSlice = (state: RootState) => state.changeBoa
export const changeBoardModalSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zChangeBoardModalState,
schema: zChangeBoardModalState,
getInitialState,
};

View File

@@ -112,7 +112,7 @@ const getInitialState = (): CanvasSettingsState => ({
pressureSensitivity: true,
ruleOfThirds: false,
saveAllImagesToGallery: false,
stagingAreaAutoSwitch: 'switch_on_start' as const,
stagingAreaAutoSwitch: 'switch_on_start',
});
const slice = createSlice({
@@ -209,7 +209,7 @@ export const {
export const canvasSettingsSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zCanvasSettingsState,
schema: zCanvasSettingsState,
getInitialState,
persistConfig: {
migrate: (state) => zCanvasSettingsState.parse(state),

View File

@@ -1720,7 +1720,7 @@ const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
export const canvasSliceConfig: SliceConfig<typeof slice> = {
slice,
getInitialState: getInitialCanvasState,
zSchema: zCanvasState,
schema: zCanvasState,
persistConfig: {
migrate: (state) => zCanvasState.parse(state),
},

View File

@@ -3,19 +3,25 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import type { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';
import { queueApi } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
import z from 'zod';
const zCanvasStagingAreaState = z.object({
_version: z.literal(1).default(1),
canvasSessionId: z.string().default(() => getPrefixedId('canvas')),
canvasDiscardedQueueItems: z.array(z.number().int()).default(() => []),
_version: z.literal(1),
canvasSessionId: z.string(),
canvasDiscardedQueueItems: z.array(z.number().int()),
});
type CanvasStagingAreaState = z.infer<typeof zCanvasStagingAreaState>;
const getInitialState = (): CanvasStagingAreaState => zCanvasStagingAreaState.parse({});
const getInitialState = (): CanvasStagingAreaState => ({
_version: 1,
canvasSessionId: getPrefixedId('canvas'),
canvasDiscardedQueueItems: [],
});
const slice = createSlice({
name: 'canvasSession',
@@ -48,18 +54,17 @@ export const { canvasSessionReset, canvasQueueItemDiscarded } = slice.actions;
export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zCanvasStagingAreaState,
schema: zCanvasStagingAreaState,
getInitialState,
persistConfig: {
migrate: (state) => {
{
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
return zCanvasStagingAreaState.parse(state);
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
return zCanvasStagingAreaState.parse(state);
},
},
};

View File

@@ -2,14 +2,16 @@ import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolki
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { LoRA } from 'features/controlLayers/store/types';
import { type LoRA, zLoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import z from 'zod';
type LoRAsState = {
loras: LoRA[];
};
const zLoRAsState = z.object({
loras: z.array(zLoRA),
});
type LoRAsState = z.infer<typeof zLoRAsState>;
const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
@@ -74,16 +76,12 @@ const slice = createSlice({
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const lorasSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zLoRAsState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => zLoRAsState.parse(state),
},
};

View File

@@ -403,7 +403,7 @@ export const {
export const paramsSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zParamsState,
schema: zParamsState,
getInitialState: getInitialParamsState,
persistConfig: {
migrate: (state) => zParamsState.parse(state),

View File

@@ -266,17 +266,12 @@ export const {
refImagesRecalled,
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const refImagesSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zRefImagesState,
schema: zRefImagesState,
getInitialState: getInitialRefImagesState,
persistConfig: {
migrate,
migrate: (state) => zRefImagesState.parse(state),
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
},
};

View File

@@ -3,7 +3,6 @@ import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEnt
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import type { ProgressImage } from 'features/nodes/types/common';
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import {
zParameterCanvasCoherenceMode,
zParameterCFGRescaleMultiplier,
@@ -45,7 +44,7 @@ const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async
}
});
const zImageWithDims = z
export const zImageWithDims = z
.object({
image_name: z.string(),
width: z.number().int().positive(),
@@ -424,12 +423,13 @@ export const zCanvasEntityIdentifer = z.object({
});
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
export type LoRA = {
id: string;
isEnabled: boolean;
model: ParameterLoRAModel;
weight: number;
};
export const zLoRA = z.object({
id: z.string(),
isEnabled: z.boolean(),
model: zServerValidatedModelIdentifierField,
weight: z.number().gte(-1).lte(2),
});
export type LoRA = z.infer<typeof zLoRA>;
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
@@ -574,11 +574,11 @@ export const zParamsState = z.object({
export type ParamsState = z.infer<typeof zParamsState>;
export const getInitialParamsState = (): ParamsState => ({
maskBlur: 16,
maskBlurMethod: 'box' as const,
canvasCoherenceMode: 'Gaussian Blur' as const,
maskBlurMethod: 'box',
canvasCoherenceMode: 'Gaussian Blur',
canvasCoherenceMinDenoise: 0,
canvasCoherenceEdgeSize: 16,
infillMethod: 'lama' as const,
infillMethod: 'lama',
infillTileSize: 32,
infillPatchmatchDownscaleSize: 1,
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
@@ -588,15 +588,15 @@ export const getInitialParamsState = (): ParamsState => ({
img2imgStrength: 0.75,
optimizedDenoisingEnabled: true,
iterations: 1,
scheduler: 'dpmpp_3m_k' as const,
upscaleScheduler: 'kdpm_2' as const,
scheduler: 'dpmpp_3m_k',
upscaleScheduler: 'kdpm_2',
upscaleCfgScale: 2,
seed: 0,
shouldRandomizeSeed: true,
steps: 30,
model: null,
vae: null,
vaePrecision: 'fp32' as const,
vaePrecision: 'fp32',
fluxVAE: null,
seamlessXAxis: false,
seamlessYAxis: false,
@@ -610,7 +610,7 @@ export const getInitialParamsState = (): ParamsState => ({
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
refinerScheduler: 'euler' as const,
refinerScheduler: 'euler',
refinerPositiveAestheticScore: 6,
refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8,
@@ -653,7 +653,7 @@ export const zCanvasState = z.object({
});
export type CanvasState = z.infer<typeof zCanvasState>;
export const getInitialCanvasState = (): CanvasState => ({
_version: 3 as const,
_version: 3,
selectedEntityIdentifier: null,
bookmarkedEntityIdentifier: null,
inpaintMasks: { isHidden: false, entities: [] },
@@ -663,9 +663,9 @@ export const getInitialCanvasState = (): CanvasState => ({
bbox: {
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
scaleMethod: 'auto' as const,
scaleMethod: 'auto',
scaledSize: { width: 512, height: 512 },
modelBase: 'sd-1' as const,
modelBase: 'sd-1',
},
});

View File

@@ -3,6 +3,8 @@ import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { buildZodTypeGuard } from 'common/util/zodUtils';
import { isPlainObject } from 'es-toolkit';
import { assert } from 'tsafe';
import { z } from 'zod';
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
@@ -19,7 +21,7 @@ const zDynamicPromptsState = z.object({
isLoading: z.boolean(),
seedBehaviour: zSeedBehaviour,
});
type DynamicPromptsState = z.infer<typeof zDynamicPromptsState>;
export type DynamicPromptsState = z.infer<typeof zDynamicPromptsState>;
const getInitialState = (): DynamicPromptsState => ({
_version: 1,
@@ -69,10 +71,11 @@ export const {
export const dynamicPromptsSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zDynamicPromptsState,
schema: zDynamicPromptsState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}

View File

@@ -1,5 +1,6 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
@@ -14,7 +15,7 @@ export const ImageMenuItemSendToUpscale = memo(() => {
const imageDTO = useImageDTOContext();
const handleSendToCanvas = useCallback(() => {
dispatch(upscaleInitialImageChanged(imageDTO));
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
navigationApi.switchToTab('upscaling');
toast({
id: 'SENT_TO_CANVAS',

View File

@@ -3,10 +3,19 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { uniq } from 'es-toolkit/compat';
import type { BoardRecordOrderBy } from 'services/api/types';
import { assert } from 'tsafe';
import type { BoardId, ComparisonMode, GalleryState, GalleryView, OrderDir } from './types';
import {
type BoardId,
type ComparisonMode,
type GalleryState,
type GalleryView,
type OrderDir,
zGalleryState,
} from './types';
const getInitialState = (): GalleryState => ({
selection: [],
@@ -192,19 +201,18 @@ export const {
export const selectGallerySlice = (state: RootState) => state.gallery;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const gallerySliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zGalleryState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zGalleryState.parse(state);
},
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
},
};

View File

@@ -0,0 +1,13 @@
import type { S } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type { BoardRecordOrderBy } from './types';
describe('Gallery Types', () => {
// Ensure zod types match OpenAPI types
test('BoardRecordOrderBy', () => {
assert<Equals<BoardRecordOrderBy, S['BoardRecordOrderBy']>>();
});
});

View File

@@ -1,31 +1,41 @@
import type { BoardRecordOrderBy, ImageCategory } from 'services/api/types';
import type { ImageCategory } from 'services/api/types';
import z from 'zod';
const zGalleryView = z.enum(['images', 'assets']);
export type GalleryView = z.infer<typeof zGalleryView>;
const zBoardId = z.union([z.literal('none'), z.intersection(z.string(), z.record(z.never(), z.never()))]);
export type BoardId = z.infer<typeof zBoardId>;
const zComparisonMode = z.enum(['slider', 'side-by-side', 'hover']);
export type ComparisonMode = z.infer<typeof zComparisonMode>;
const zComparisonFit = z.enum(['contain', 'fill']);
export type ComparisonFit = z.infer<typeof zComparisonFit>;
const zOrderDir = z.enum(['ASC', 'DESC']);
export type OrderDir = z.infer<typeof zOrderDir>;
const zBoardRecordOrderBy = z.enum(['created_at', 'board_name']);
export type BoardRecordOrderBy = z.infer<typeof zBoardRecordOrderBy>;
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
export type GalleryView = 'images' | 'assets';
export type BoardId = 'none' | (string & Record<never, never>);
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
export type ComparisonFit = 'contain' | 'fill';
export type OrderDir = 'ASC' | 'DESC';
export const zGalleryState = z.object({
selection: z.array(z.string()),
shouldAutoSwitch: z.boolean(),
autoAssignBoardOnClick: z.boolean(),
autoAddBoardId: zBoardId,
galleryImageMinimumWidth: z.number(),
selectedBoardId: zBoardId,
galleryView: zGalleryView,
boardSearchText: z.string(),
starredFirst: z.boolean(),
orderDir: zOrderDir,
searchTerm: z.string(),
alwaysShowImageSizeBadge: z.boolean(),
imageToCompare: z.string().nullable(),
comparisonMode: zComparisonMode,
comparisonFit: zComparisonFit,
shouldShowArchivedBoards: z.boolean(),
boardsListOrderBy: zBoardRecordOrderBy,
boardsListOrderDir: zOrderDir,
});
export type GalleryState = {
selection: string[];
shouldAutoSwitch: boolean;
autoAssignBoardOnClick: boolean;
autoAddBoardId: BoardId;
galleryImageMinimumWidth: number;
selectedBoardId: BoardId;
galleryView: GalleryView;
boardSearchText: string;
starredFirst: boolean;
orderDir: OrderDir;
searchTerm: string;
alwaysShowImageSizeBadge: boolean;
imageToCompare: string | null;
comparisonMode: ComparisonMode;
comparisonFit: ComparisonFit;
shouldShowArchivedBoards: boolean;
boardsListOrderBy: BoardRecordOrderBy;
boardsListOrderDir: OrderDir;
};
export type GalleryState = z.infer<typeof zGalleryState>;

View File

@@ -58,7 +58,7 @@ export const setRegionalGuidanceReferenceImage = (arg: {
export const setUpscaleInitialImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
const { imageDTO, dispatch } = arg;
dispatch(upscaleInitialImageChanged(imageDTO));
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
};
export const setNodeImageFieldImage = (arg: {

View File

@@ -2,19 +2,25 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import type { ModelType } from 'services/api/types';
import { isPlainObject } from 'es-toolkit';
import { zModelType } from 'features/nodes/types/common';
import { assert } from 'tsafe';
import z from 'zod';
export type FilterableModelType = Exclude<ModelType, 'onnx'> | 'refiner';
const zFilterableModelType = zModelType.exclude(['onnx']).or(z.literal('refiner'));
export type FilterableModelType = z.infer<typeof zFilterableModelType>;
type ModelManagerState = {
_version: 1;
selectedModelKey: string | null;
selectedModelMode: 'edit' | 'view';
searchTerm: string;
filteredModelType: FilterableModelType | null;
scanPath: string | undefined;
shouldInstallInPlace: boolean;
};
const zModelManagerState = z.object({
_version: z.literal(1),
selectedModelKey: z.string().nullable(),
selectedModelMode: z.enum(['edit', 'view']),
searchTerm: z.string(),
filteredModelType: zFilterableModelType.nullable(),
scanPath: z.string().optional(),
shouldInstallInPlace: z.boolean(),
});
type ModelManagerState = z.infer<typeof zModelManagerState>;
const getInitialState = (): ModelManagerState => ({
_version: 1,
@@ -61,19 +67,18 @@ export const {
shouldInstallInPlaceChanged,
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zModelManagerState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zModelManagerState.parse(state);
},
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
},
};

View File

@@ -14,7 +14,13 @@ import type {
ReactFlowProps,
ReactFlowState,
} from '@xyflow/react';
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from '@xyflow/react';
import {
Background,
ReactFlow,
SelectionMode,
useStore as useReactFlowStore,
useUpdateNodeInternals,
} from '@xyflow/react';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish';
@@ -256,7 +262,7 @@ export const Flow = memo(() => {
style={flowStyles}
onPaneClick={handlePaneClick}
deleteKeyCode={null}
selectionMode={selectionMode}
selectionMode={selectionMode === 'full' ? SelectionMode.Full : SelectionMode.Partial}
elevateEdgesOnSelect
nodeDragThreshold={1}
noDragClassName={NO_DRAG_CLASS}

View File

@@ -13,12 +13,13 @@ import type {
import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from '@xyflow/react';
import type { SliceConfig } from 'app/store/types';
import { deepClone } from 'common/util/deepClone';
import { isPlainObject } from 'es-toolkit';
import {
addElement,
removeElement,
reparentElement,
} from 'features/nodes/components/sidePanel/builder/form-manipulation';
import type { NodesState } from 'features/nodes/store/types';
import { type NodesState, zNodesState } from 'features/nodes/store/types';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
@@ -127,6 +128,7 @@ import {
import { atom, computed } from 'nanostores';
import type { MouseEvent } from 'react';
import type { UndoableOptions } from 'redux-undo';
import { assert } from 'tsafe';
import type { z } from 'zod';
import type { PendingConnection, Templates } from './types';
@@ -760,14 +762,6 @@ export const {
redo,
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const $cursorPos = atom<XYPosition | null>(null);
export const $templates = atom<Templates>({});
export const $hasTemplates = computed($templates, (templates) => Object.keys(templates).length > 0);
@@ -938,9 +932,16 @@ const reduxUndoOptions: UndoableOptions<NodesState, UnknownAction> = {
export const nodesSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zNodesState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zNodesState.parse(state);
},
},
undoableConfig: {
reduxUndoOptions,

View File

@@ -1,7 +1,8 @@
import type { HandleType } from '@xyflow/react';
import type { FieldInputTemplate, FieldOutputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
import type { AnyEdge, AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { type FieldInputTemplate, type FieldOutputTemplate, zStatefulFieldValue } from 'features/nodes/types/field';
import { type InvocationTemplate, type NodeExecutionState, zAnyEdge, zAnyNode } from 'features/nodes/types/invocation';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import z from 'zod';
export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
@@ -13,11 +14,13 @@ export type PendingConnection = {
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
};
export type WorkflowMode = 'edit' | 'view';
export type NodesState = {
_version: 1;
nodes: AnyNode[];
edges: AnyEdge[];
formFieldInitialValues: Record<string, StatefulFieldValue>;
} & Omit<WorkflowV3, 'nodes' | 'edges' | 'is_published'>;
export const zWorkflowMode = z.enum(['edit', 'view']);
export type WorkflowMode = z.infer<typeof zWorkflowMode>;
export const zNodesState = z.object({
_version: z.literal(1),
nodes: z.array(zAnyNode),
edges: z.array(zAnyEdge),
formFieldInitialValues: z.record(z.string(), zStatefulFieldValue),
...zWorkflowV3.omit({ nodes: true, edges: true, is_published: true }).shape,
});
export type NodesState = z.infer<typeof zNodesState>;

View File

@@ -2,21 +2,29 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import type { WorkflowMode } from 'features/nodes/store/types';
import { type WorkflowMode, zWorkflowMode } from 'features/nodes/store/types';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import { atom, computed } from 'nanostores';
import type { SQLiteDirection, WorkflowRecordOrderBy } from 'services/api/types';
import {
type SQLiteDirection,
type WorkflowRecordOrderBy,
zSQLiteDirection,
zWorkflowRecordOrderBy,
} from 'services/api/types';
import z from 'zod';
export type WorkflowLibraryView = 'recent' | 'yours' | 'private' | 'shared' | 'defaults' | 'published';
const zWorkflowLibraryView = z.enum(['recent', 'yours', 'private', 'shared', 'defaults', 'published']);
export type WorkflowLibraryView = z.infer<typeof zWorkflowLibraryView>;
type WorkflowLibraryState = {
mode: WorkflowMode;
view: WorkflowLibraryView;
orderBy: WorkflowRecordOrderBy;
direction: SQLiteDirection;
searchTerm: string;
selectedTags: string[];
};
const zWorkflowLibraryState = z.object({
mode: zWorkflowMode,
view: zWorkflowLibraryView,
orderBy: zWorkflowRecordOrderBy,
direction: zSQLiteDirection,
searchTerm: z.string(),
selectedTags: z.array(z.string()),
});
type WorkflowLibraryState = z.infer<typeof zWorkflowLibraryState>;
const getInitialState = (): WorkflowLibraryState => ({
mode: 'view',
@@ -76,14 +84,12 @@ export const {
workflowLibraryViewChanged,
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => state;
export const workflowLibrarySliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zWorkflowLibraryState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => zWorkflowLibraryState.parse(state),
},
};

View File

@@ -1,9 +1,10 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import { SelectionMode } from '@xyflow/react';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { Selector } from 'react-redux';
import { assert } from 'tsafe';
import z from 'zod';
export const zLayeringStrategy = z.enum(['network-simplex', 'longest-path']);
@@ -12,23 +13,26 @@ export const zLayoutDirection = z.enum(['TB', 'LR']);
type LayoutDirection = z.infer<typeof zLayoutDirection>;
export const zNodeAlignment = z.enum(['UL', 'UR', 'DL', 'DR']);
type NodeAlignment = z.infer<typeof zNodeAlignment>;
const zSelectionMode = z.enum(['partial', 'full']);
export type WorkflowSettingsState = {
_version: 1;
shouldShowMinimapPanel: boolean;
layeringStrategy: LayeringStrategy;
nodeSpacing: number;
layerSpacing: number;
layoutDirection: LayoutDirection;
shouldValidateGraph: boolean;
shouldAnimateEdges: boolean;
nodeAlignment: NodeAlignment;
nodeOpacity: number;
shouldSnapToGrid: boolean;
shouldColorEdges: boolean;
shouldShowEdgeLabels: boolean;
selectionMode: SelectionMode;
};
const zWorkflowSettingsState = z.object({
_version: z.literal(1),
shouldShowMinimapPanel: z.boolean(),
layeringStrategy: zLayeringStrategy,
nodeSpacing: z.number(),
layerSpacing: z.number(),
layoutDirection: zLayoutDirection,
shouldValidateGraph: z.boolean(),
shouldAnimateEdges: z.boolean(),
nodeAlignment: zNodeAlignment,
nodeOpacity: z.number(),
shouldSnapToGrid: z.boolean(),
shouldColorEdges: z.boolean(),
shouldShowEdgeLabels: z.boolean(),
selectionMode: zSelectionMode,
});
export type WorkflowSettingsState = z.infer<typeof zWorkflowSettingsState>;
const getInitialState = (): WorkflowSettingsState => ({
_version: 1,
@@ -44,7 +48,7 @@ const getInitialState = (): WorkflowSettingsState => ({
shouldColorEdges: true,
shouldShowEdgeLabels: false,
nodeOpacity: 1,
selectionMode: SelectionMode.Partial,
selectionMode: 'partial',
});
const slice = createSlice({
@@ -88,7 +92,7 @@ const slice = createSlice({
state.nodeAlignment = action.payload;
},
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
state.selectionMode = action.payload ? 'full' : 'partial';
},
},
});
@@ -109,19 +113,18 @@ export const {
selectionModeChanged,
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const workflowSettingsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zWorkflowSettingsState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zWorkflowSettingsState.parse(state);
},
},
};

View File

@@ -92,7 +92,7 @@ export const zMainModelBase = z.enum([
]);
type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
export const zModelType = z.enum([
'main',
'vae',
'lora',

View File

@@ -43,7 +43,7 @@ export const zNotesNodeData = z.object({
isOpen: z.boolean(),
notes: z.string(),
});
const _zCurrentImageNodeData = z.object({
const zCurrentImageNodeData = z.object({
id: z.string().trim().min(1),
type: z.literal('current_image'),
label: z.string(),
@@ -52,12 +52,35 @@ const _zCurrentImageNodeData = z.object({
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
type CurrentImageNodeData = z.infer<typeof _zCurrentImageNodeData>;
type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
export type NotesNode = Node<NotesNodeData, 'notes'>;
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
export type AnyNode = InvocationNode | NotesNode | CurrentImageNode;
const zInvocationNodeValidationSchema = z.looseObject({
type: z.literal('invocation'),
data: zInvocationNodeData,
});
const zInvocationNode = z.custom<Node<InvocationNodeData, 'invocation'>>(
(val) => zInvocationNodeValidationSchema.safeParse(val).success
);
export type InvocationNode = z.infer<typeof zInvocationNode>;
const zNotesNodeValidationSchema = z.looseObject({
type: z.literal('notes'),
data: zNotesNodeData,
});
const zNotesNode = z.custom<Node<NotesNodeData, 'notes'>>((val) => zNotesNodeValidationSchema.safeParse(val).success);
export type NotesNode = z.infer<typeof zNotesNode>;
const zCurrentImageNodeValidationSchema = z.looseObject({
type: z.literal('current_image'),
data: zCurrentImageNodeData,
});
const zCurrentImageNode = z.custom<Node<CurrentImageNodeData, 'current_image'>>(
(val) => zCurrentImageNodeValidationSchema.safeParse(val).success
);
export type CurrentImageNode = z.infer<typeof zCurrentImageNode>;
export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]);
export type AnyNode = z.infer<typeof zAnyNode>;
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
Boolean(node && node.type === 'invocation');
@@ -83,13 +106,29 @@ export type NodeExecutionState = z.infer<typeof _zNodeExecutionState>;
// #endregion
// #region Edges
const _zInvocationNodeEdgeCollapsedData = z.object({
const zDefaultInvocationNodeEdgeValidationSchema = z.looseObject({
type: z.literal('default'),
});
const zDefaultInvocationNodeEdge = z.custom<Edge<Record<string, never>, 'default'>>(
(val) => zDefaultInvocationNodeEdgeValidationSchema.safeParse(val).success
);
export type DefaultInvocationNodeEdge = z.infer<typeof zDefaultInvocationNodeEdge>;
const zInvocationNodeEdgeCollapsedData = z.object({
count: z.number().int().min(1),
});
type InvocationNodeEdgeCollapsedData = z.infer<typeof _zInvocationNodeEdgeCollapsedData>;
export type DefaultInvocationNodeEdge = Edge<Record<string, never>, 'default'>;
export type CollapsedInvocationNodeEdge = Edge<InvocationNodeEdgeCollapsedData, 'collapsed'>;
export type AnyEdge = DefaultInvocationNodeEdge | CollapsedInvocationNodeEdge;
const zInvocationNodeEdgeCollapsedValidationSchema = z.looseObject({
type: z.literal('default'),
data: zInvocationNodeEdgeCollapsedData,
});
type InvocationNodeEdgeCollapsedData = z.infer<typeof zInvocationNodeEdgeCollapsedData>;
const zCollapsedInvocationNodeEdge = z.custom<Edge<InvocationNodeEdgeCollapsedData, 'collapsed'>>(
(val) => zInvocationNodeEdgeCollapsedValidationSchema.safeParse(val).success
);
export type CollapsedInvocationNodeEdge = z.infer<typeof zCollapsedInvocationNodeEdge>;
export const zAnyEdge = z.union([zDefaultInvocationNodeEdge, zCollapsedInvocationNodeEdge]);
export type AnyEdge = z.infer<typeof zAnyEdge>;
// #endregion
export const isBatchNodeType = (type: string) =>

View File

@@ -1,4 +1,5 @@
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
@@ -6,13 +7,35 @@ import { ModelPicker } from 'features/parameters/components/ModelPicker';
import { selectTileControlNetModel, tileControlnetModelChanged } from 'features/parameters/store/upscaleSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig } from 'services/api/types';
import { type ControlNetModelConfig, isControlNetModelConfig } from 'services/api/types';
const selectTileControlNetModelConfig = createSelector(
selectModelConfigsQuery,
selectTileControlNetModel,
(modelConfigs, modelIdentifierField) => {
if (!modelConfigs.data) {
return null;
}
if (!modelIdentifierField) {
return null;
}
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, modelIdentifierField.key);
if (!modelConfig) {
return null;
}
if (!isControlNetModelConfig(modelConfig)) {
return null;
}
return modelConfig;
}
);
const ParamTileControlNetModel = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const tileControlNetModel = useAppSelector(selectTileControlNetModel);
const tileControlNetModel = useAppSelector(selectTileControlNetModelConfig);
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useControlNetModels();

View File

@@ -1,21 +1,21 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { useMemo } from 'react';
import type { ImageDTO } from 'services/api/types';
const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) =>
const createIsTooLargeToUpscaleSelector = (imageWithDims?: ImageWithDims | null) =>
createSelector(selectUpscaleSlice, selectConfigSlice, (upscale, config) => {
const { upscaleModel, scale } = upscale;
const { maxUpscaleDimension } = config;
if (!maxUpscaleDimension || !upscaleModel || !imageDTO) {
if (!maxUpscaleDimension || !upscaleModel || !imageWithDims) {
// When these are missing, another warning will be shown
return false;
}
const { width, height } = imageDTO;
const { width, height } = imageWithDims;
const maxPixels = maxUpscaleDimension ** 2;
const upscaledPixels = width * scale * height * scale;
@@ -23,7 +23,7 @@ const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) =>
return upscaledPixels > maxPixels;
});
export const useIsTooLargeToUpscale = (imageDTO?: ImageDTO | null) => {
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageDTO), [imageDTO]);
export const useIsTooLargeToUpscale = (imageWithDims?: ImageWithDims | null) => {
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageWithDims), [imageWithDims]);
return useAppSelector(selectIsTooLargeToUpscale);
};

View File

@@ -2,24 +2,32 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterSpandrelImageToImageModel } from 'features/parameters/types/parameterSchemas';
import type { ControlNetModelConfig, ImageDTO } from 'services/api/types';
import type { ControlNetModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import z from 'zod';
export interface UpscaleState {
_version: 1;
upscaleModel: ParameterSpandrelImageToImageModel | null;
upscaleInitialImage: ImageDTO | null;
structure: number;
creativity: number;
tileControlnetModel: ControlNetModelConfig | null;
scale: number;
postProcessingModel: ParameterSpandrelImageToImageModel | null;
tileSize: number;
tileOverlap: number;
}
const zUpscaleState = z.object({
_version: z.literal(2),
upscaleModel: zModelIdentifierField.nullable(),
upscaleInitialImage: zImageWithDims.nullable(),
structure: z.number(),
creativity: z.number(),
tileControlnetModel: zModelIdentifierField.nullable(),
scale: z.number(),
postProcessingModel: zModelIdentifierField.nullable(),
tileSize: z.number(),
tileOverlap: z.number(),
});
export type UpscaleState = z.infer<typeof zUpscaleState>;
const getInitialState = (): UpscaleState => ({
_version: 1,
_version: 2,
upscaleModel: null,
upscaleInitialImage: null,
structure: 0,
@@ -38,7 +46,7 @@ const slice = createSlice({
upscaleModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
state.upscaleModel = action.payload;
},
upscaleInitialImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
upscaleInitialImageChanged: (state, action: PayloadAction<ImageWithDims | null>) => {
state.upscaleInitialImage = action.payload;
},
structureChanged: (state, action: PayloadAction<number>) => {
@@ -77,19 +85,30 @@ export const {
tileOverlapChanged,
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const upscaleSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zUpscaleState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state._version = 2;
// Migrate from v1 to v2: upscaleInitialImage was an ImageDTO, now it's an ImageWithDims
if (state.upscaleInitialImage) {
const { image_name, width, height } = state.upscaleInitialImage;
state.upscaleInitialImage = {
image_name,
width,
height,
};
}
}
return zUpscaleState.parse(state);
},
},
};

View File

@@ -2,13 +2,15 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import z from 'zod';
interface QueueState {
listCursor: number | undefined;
listPriority: number | undefined;
selectedQueueItem: string | undefined;
resumeProcessorOnEnqueue: boolean;
}
const zQueueState = z.object({
listCursor: z.number().optional(),
listPriority: z.number().optional(),
selectedQueueItem: z.string().optional(),
resumeProcessorOnEnqueue: z.boolean(),
});
type QueueState = z.infer<typeof zQueueState>;
const getInitialState = (): QueueState => ({
listCursor: undefined,
@@ -38,6 +40,7 @@ export const { listCursorChanged, listPriorityChanged, listParamsReset } = slice
export const queueSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zQueueState,
getInitialState,
};

View File

@@ -1,6 +1,7 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import type { SetUpscaleInitialImageDndTargetData } from 'features/dnd/dnd';
import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
@@ -10,11 +11,13 @@ import { selectUpscaleInitialImage, upscaleInitialImageChanged } from 'features/
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const UpscaleInitialImage = () => {
const dispatch = useAppDispatch();
const imageDTO = useAppSelector(selectUpscaleInitialImage);
const upscaleInitialImage = useAppSelector(selectUpscaleInitialImage);
const imageDTO = useImageDTO(upscaleInitialImage?.image_name);
const dndTargetData = useMemo<SetUpscaleInitialImageDndTargetData>(
() => setUpscaleInitialImageDndTarget.getData(),
[]
@@ -26,7 +29,7 @@ export const UpscaleInitialImage = () => {
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
dispatch(upscaleInitialImageChanged(imageDTO));
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
},
[dispatch]
);

View File

@@ -2,11 +2,21 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { atom } from 'nanostores';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { assert } from 'tsafe';
import z from 'zod';
import type { StylePresetState } from './types';
const zStylePresetState = z.object({
activeStylePresetId: z.string().nullable(),
searchTerm: z.string(),
viewMode: z.boolean(),
showPromptPreviews: z.boolean(),
});
type StylePresetState = z.infer<typeof zStylePresetState>;
const getInitialState = (): StylePresetState => ({
activeStylePresetId: null,
@@ -60,19 +70,18 @@ const slice = createSlice({
export const { activeStylePresetIdChanged, searchTermChanged, viewModeChanged, showPromptPreviewsChanged } =
slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const stylePresetSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zStylePresetState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zStylePresetState.parse(state);
},
},
};

View File

@@ -1,6 +0,0 @@
export type StylePresetState = {
activeStylePresetId: string | null;
searchTerm: string;
viewMode: boolean;
showPromptPreviews: boolean;
};

View File

@@ -32,7 +32,7 @@ export const { configChanged } = slice.actions;
export const configSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zConfigState,
schema: zConfigState,
getInitialState,
};

View File

@@ -5,9 +5,11 @@ import { zLogNamespace } from 'app/logging/logger';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { uniq } from 'es-toolkit/compat';
import { assert } from 'tsafe';
import type { Language, SystemState } from './types';
import { type Language, type SystemState, zSystemState } from './types';
const getInitialState = (): SystemState => ({
_version: 2,
@@ -92,23 +94,22 @@ export const {
setShouldHighlightFocusedRegions,
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.language = (state as SystemState).language.replace('_', '-');
state._version = 2;
}
return state;
};
export const systemSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zSystemState,
getInitialState,
persistConfig: {
migrate,
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.language = (state as SystemState).language.replace('_', '-');
state._version = 2;
}
return zSystemState.parse(state);
},
},
};

View File

@@ -1,4 +1,4 @@
import type { LogLevel, LogNamespace } from 'app/logging/logger';
import { zLogLevel, zLogNamespace } from 'app/logging/logger';
import { z } from 'zod';
const zLanguage = z.enum([
@@ -29,19 +29,20 @@ const zLanguage = z.enum([
export type Language = z.infer<typeof zLanguage>;
export const isLanguage = (v: unknown): v is Language => zLanguage.safeParse(v).success;
export interface SystemState {
_version: 2;
shouldConfirmOnDelete: boolean;
shouldAntialiasProgressImage: boolean;
shouldConfirmOnNewSession: boolean;
language: Language;
shouldUseNSFWChecker: boolean;
shouldUseWatermarker: boolean;
shouldEnableInformationalPopovers: boolean;
shouldEnableModelDescriptions: boolean;
logIsEnabled: boolean;
logLevel: LogLevel;
logNamespaces: LogNamespace[];
shouldShowInvocationProgressDetail: boolean;
shouldHighlightFocusedRegions: boolean;
}
export const zSystemState = z.object({
_version: z.literal(2),
shouldConfirmOnDelete: z.boolean(),
shouldAntialiasProgressImage: z.boolean(),
shouldConfirmOnNewSession: z.boolean(),
language: zLanguage,
shouldUseNSFWChecker: z.boolean(),
shouldUseWatermarker: z.boolean(),
shouldEnableInformationalPopovers: z.boolean(),
shouldEnableModelDescriptions: z.boolean(),
logIsEnabled: z.boolean(),
logLevel: zLogLevel,
logNamespaces: z.array(zLogNamespace),
shouldShowInvocationProgressDetail: z.boolean(),
shouldHighlightFocusedRegions: z.boolean(),
});
export type SystemState = z.infer<typeof zSystemState>;

View File

@@ -1,6 +1,7 @@
import { Box, Button, ButtonGroup, Flex, Grid, Heading, Icon, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import {
@@ -37,7 +38,7 @@ export const UpscalingLaunchpadPanel = memo(() => {
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
dispatch(upscaleInitialImageChanged(imageDTO));
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
},
[dispatch]
);

View File

@@ -2,6 +2,8 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { assert } from 'tsafe';
import { getInitialUIState, type UIState, zUIState } from './uiTypes';
@@ -87,10 +89,11 @@ export const selectUiSlice = (state: RootState) => state.ui;
export const uiSliceConfig: SliceConfig<typeof slice> = {
slice,
zSchema: zUIState,
schema: zUIState,
getInitialState: getInitialUIState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}

View File

@@ -1,6 +1,9 @@
import type { Dimensions } from 'features/controlLayers/store/types';
import type { components, paths } from 'services/api/schema';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import type { JsonObject, SetRequired } from 'type-fest';
import z from 'zod';
export type S = components['schemas'];
@@ -33,10 +36,36 @@ export type InvocationJSONSchemaExtra = S['UIConfigBase'];
export type AppVersion = S['AppVersion'];
export type AppConfig = S['AppConfig'];
const zResourceOrigin = z.enum(['internal', 'external']);
type ResourceOrigin = z.infer<typeof zResourceOrigin>;
assert<Equals<ResourceOrigin, S['ResourceOrigin']>>();
const zImageCategory = z.enum(['general', 'mask', 'control', 'user', 'other']);
export type ImageCategory = z.infer<typeof zImageCategory>;
assert<Equals<ImageCategory, S['ImageCategory']>>();
// Images
export type ImageDTO = S['ImageDTO'];
const _zImageDTO = z.object({
image_name: z.string(),
image_url: z.string(),
thumbnail_url: z.string(),
image_origin: zResourceOrigin,
image_category: zImageCategory,
width: z.number().int().gt(0),
height: z.number().int().gt(0),
created_at: z.string(),
updated_at: z.string(),
deleted_at: z.string().nullish(),
is_intermediate: z.boolean(),
session_id: z.string().nullish(),
node_id: z.string().nullish(),
starred: z.boolean(),
has_workflow: z.boolean(),
board_id: z.string().nullish(),
});
export type ImageDTO = z.infer<typeof _zImageDTO>;
assert<Equals<ImageDTO, S['ImageDTO']>>();
export type BoardDTO = S['BoardDTO'];
export type ImageCategory = S['ImageCategory'];
export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDTO_'];
// Models
@@ -298,8 +327,13 @@ export type ModelInstallStatus = S['InstallStatus'];
export type Graph = S['Graph'];
export type NonNullableGraph = SetRequired<Graph, 'nodes' | 'edges'>;
export type Batch = S['Batch'];
export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
export type SQLiteDirection = S['SQLiteDirection'];
export const zWorkflowRecordOrderBy = z.enum(['name', 'created_at', 'updated_at', 'opened_at']);
export type WorkflowRecordOrderBy = z.infer<typeof zWorkflowRecordOrderBy>;
assert<Equals<S['WorkflowRecordOrderBy'], WorkflowRecordOrderBy>>();
export const zSQLiteDirection = z.enum(['ASC', 'DESC']);
export type SQLiteDirection = z.infer<typeof zSQLiteDirection>;
assert<Equals<S['SQLiteDirection'], SQLiteDirection>>();
export type WorkflowRecordListItemWithThumbnailDTO = S['WorkflowRecordListItemWithThumbnailDTO'];
type KeysOfUnion<T> = T extends T ? keyof T : never;