diff --git a/invokeai/frontend/web/src/common/util/zodUtils.ts b/invokeai/frontend/web/src/common/util/zodUtils.ts new file mode 100644 index 0000000000..10506736e1 --- /dev/null +++ b/invokeai/frontend/web/src/common/util/zodUtils.ts @@ -0,0 +1,10 @@ +import type { z } from 'zod'; + +/** + * Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type. + * @param schema The zod schema to create a type guard from. + * @returns A type guard function for the schema. + */ +export const buildZodTypeGuard = (schema: T) => { + return (val: unknown): val is z.infer => schema.safeParse(val).success; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 5fd71f8158..cbcd0dcb26 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -70,7 +70,9 @@ import type { T2IAdapterConfig, } from './types'; import { + DEFAULT_ASPECT_RATIO_CONFIG, getEntityIdentifier, + getInitialState, isChatGPT4oAspectRatioID, isFluxKontextAspectRatioID, isImagenAspectRatioID, @@ -93,55 +95,9 @@ import { initialT2IAdapter, } from './util'; -/** - * Gets a fresh canvas initial state with no references in memory to existing objects. - */ -const getInitialState = (): CanvasState => { - const initialInpaintMaskState = getInpaintMaskState(getPrefixedId('inpaint_mask')); - const initialState: CanvasState = { - _version: 3, - selectedEntityIdentifier: getEntityIdentifier(initialInpaintMaskState), - bookmarkedEntityIdentifier: getEntityIdentifier(initialInpaintMaskState), - rasterLayers: { - isHidden: false, - entities: [], - }, - controlLayers: { - isHidden: false, - entities: [], - }, - inpaintMasks: { - isHidden: false, - entities: [initialInpaintMaskState], - }, - regionalGuidance: { - isHidden: false, - entities: [], - }, - referenceImages: { entities: [] }, - bbox: { - rect: { x: 0, y: 0, width: 512, height: 512 }, - aspectRatio: { - id: '1:1', - value: 1, - isLocked: false, - }, - scaleMethod: 'auto', - scaledSize: { - width: 512, - height: 512, - }, - modelBase: 'sd-1', - }, - }; - return initialState; -}; - -const initialState = getInitialState(); - export const canvasSlice = createSlice({ name: 'canvas', - initialState, + initialState: getInitialState(), reducers: { // undoable canvas state //#region Raster layers @@ -1409,7 +1365,7 @@ export const canvasSlice = createSlice({ state.bbox.rect.width = width; state.bbox.rect.height = height; } else { - state.bbox.aspectRatio = deepClone(initialState.bbox.aspectRatio); + state.bbox.aspectRatio = deepClone(DEFAULT_ASPECT_RATIO_CONFIG); state.bbox.rect.width = optimalDimension; state.bbox.rect.height = optimalDimension; } @@ -1975,7 +1931,7 @@ const migrate = (state: any): any => { export const canvasPersistConfig: PersistConfig = { name: canvasSlice.name, - initialState, + initialState: getInitialState(), migrate, persistDenylist: [], }; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts b/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts index 32114001e3..fe37d1891b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts @@ -1,3 +1,4 @@ +import type { Selector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; @@ -23,6 +24,9 @@ import { assert } from 'tsafe'; */ export const selectCanvasSlice = (state: RootState) => state.canvas.present; +export const createCanvasSelector = (selector: Selector) => + createSelector(selectCanvasSlice, selector); + /** * Selects the total canvas entity count: * - Regions @@ -33,7 +37,7 @@ export const selectCanvasSlice = (state: RootState) => state.canvas.present; * * All entities are counted, regardless of their state. */ -const selectEntityCountAll = createSelector(selectCanvasSlice, (canvas) => { +const selectEntityCountAll = createCanvasSelector((canvas) => { return ( canvas.regionalGuidance.entities.length + canvas.referenceImages.entities.length + @@ -45,24 +49,29 @@ const selectEntityCountAll = createSelector(selectCanvasSlice, (canvas) => { const isVisibleEntity = (entity: CanvasRenderableEntityState) => entity.isEnabled && entity.objects.length > 0; -export const selectActiveRasterLayerEntities = createSelector(selectCanvasSlice, (canvas) => - canvas.rasterLayers.entities.filter(isVisibleEntity) +export const selectRasterLayerEntities = createCanvasSelector((canvas) => canvas.rasterLayers.entities); +export const selectActiveRasterLayerEntities = createSelector(selectRasterLayerEntities, (entities) => + entities.filter(isVisibleEntity) ); -export const selectActiveControlLayerEntities = createSelector(selectCanvasSlice, (canvas) => - canvas.controlLayers.entities.filter(isVisibleEntity) +export const selectControlLayerEntities = createCanvasSelector((canvas) => canvas.controlLayers.entities); +export const selectActiveControlLayerEntities = createSelector(selectControlLayerEntities, (entities) => + entities.filter(isVisibleEntity) ); -export const selectActiveInpaintMaskEntities = createSelector(selectCanvasSlice, (canvas) => - canvas.inpaintMasks.entities.filter(isVisibleEntity) +export const selectInpaintMaskEntities = createCanvasSelector((canvas) => canvas.inpaintMasks.entities); +export const selectActiveInpaintMaskEntities = createSelector(selectInpaintMaskEntities, (entities) => + entities.filter(isVisibleEntity) ); -export const selectActiveRegionalGuidanceEntities = createSelector(selectCanvasSlice, (canvas) => - canvas.regionalGuidance.entities.filter(isVisibleEntity) +export const selectRegionalGuidanceEntities = createCanvasSelector((canvas) => canvas.regionalGuidance.entities); +export const selectActiveRegionalGuidanceEntities = createSelector(selectRegionalGuidanceEntities, (entities) => + entities.filter(isVisibleEntity) ); -export const selectActiveReferenceImageEntities = createSelector(selectCanvasSlice, (canvas) => - canvas.referenceImages.entities.filter((e) => e.isEnabled) +export const selectReferenceImageEntities = createCanvasSelector((canvas) => canvas.referenceImages.entities); +export const selectActiveReferenceImageEntities = createSelector(selectReferenceImageEntities, (entities) => + entities.filter((e) => e.isEnabled) ); /** @@ -192,14 +201,6 @@ export function selectEntityIdentifierBelowThisOne | undefined; } -export const selectRasterLayerEntities = createSelector(selectCanvasSlice, (canvas) => canvas.rasterLayers.entities); -export const selectControlLayerEntities = createSelector(selectCanvasSlice, (canvas) => canvas.controlLayers.entities); -export const selectInpaintMaskEntities = createSelector(selectCanvasSlice, (canvas) => canvas.inpaintMasks.entities); -export const selectRegionalGuidanceEntities = createSelector( - selectCanvasSlice, - (canvas) => canvas.regionalGuidance.entities -); - /** * Selected an entity from the canvas slice. If the entity is not found, an error is thrown. * @@ -218,7 +219,7 @@ export function selectEntityOrThrow( } export const selectEntityExists = (entityIdentifier: T) => { - return createSelector(selectCanvasSlice, (canvas) => Boolean(selectEntity(canvas, entityIdentifier))); + return createCanvasSelector((canvas) => Boolean(selectEntity(canvas, entityIdentifier))); }; /** @@ -299,7 +300,7 @@ export function selectRegionalGuidanceReferenceImage( return entity.referenceImages.find(({ id }) => id === referenceImageId); } -export const selectBbox = createSelector(selectCanvasSlice, (canvas) => canvas.bbox); +export const selectBbox = createCanvasSelector((canvas) => canvas.bbox); export const selectSelectedEntityIdentifier = createSelector( selectCanvasSlice, @@ -331,10 +332,10 @@ export const selectSelectedEntityFill = createSelector( } ); -const selectRasterLayersIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.rasterLayers.isHidden); -const selectControlLayersIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.controlLayers.isHidden); -const selectInpaintMasksIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.inpaintMasks.isHidden); -const selectRegionalGuidanceIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.regionalGuidance.isHidden); +const selectRasterLayersIsHidden = createCanvasSelector((canvas) => canvas.rasterLayers.isHidden); +const selectControlLayersIsHidden = createCanvasSelector((canvas) => canvas.controlLayers.isHidden); +const selectInpaintMasksIsHidden = createCanvasSelector((canvas) => canvas.inpaintMasks.isHidden); +const selectRegionalGuidanceIsHidden = createCanvasSelector((canvas) => canvas.regionalGuidance.isHidden); /** * Returns the hidden selector for the given entity type. @@ -372,7 +373,7 @@ export const buildSelectIsSelected = (entityIdentifier: CanvasEntityIdentifier) * Other entities are considered empty if they have no objects. */ export const buildSelectHasObjects = (entityIdentifier: CanvasEntityIdentifier) => { - return createSelector(selectCanvasSlice, (canvas) => { + return createCanvasSelector((canvas) => { const entity = selectEntity(canvas, entityIdentifier); if (!entity) { @@ -385,10 +386,10 @@ export const buildSelectHasObjects = (entityIdentifier: CanvasEntityIdentifier) }); }; -export const selectWidth = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.width); -export const selectHeight = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.height); -export const selectAspectRatioID = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.aspectRatio.id); -export const selectAspectRatioValue = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.aspectRatio.value); +export const selectWidth = createCanvasSelector((canvas) => canvas.bbox.rect.width); +export const selectHeight = createCanvasSelector((canvas) => canvas.bbox.rect.height); +export const selectAspectRatioID = createCanvasSelector((canvas) => canvas.bbox.aspectRatio.id); +export const selectAspectRatioValue = createCanvasSelector((canvas) => canvas.bbox.aspectRatio.value); export const selectScaledSize = createSelector(selectBbox, (bbox) => bbox.scaledSize); export const selectScaleMethod = createSelector(selectBbox, (bbox) => bbox.scaleMethod); export const selectBboxRect = createSelector(selectBbox, (bbox) => bbox.rect); @@ -407,3 +408,21 @@ export const selectCanvasMetadata = createSelector( return { canvas_v2_metadata }; } ); + +export const selectIsSessionStarted = createCanvasSelector(({ isSessionStarted }) => isSessionStarted); +export const selectIsCanvasEmpty = createCanvasSelector( + ({ controlLayers, inpaintMasks, rasterLayers, regionalGuidance }) => { + // Check it all manually - could use lodash isEqual, but this selector will be called very often! + // Also note - we do not care about ref images, as they are technically not part of canvas + return ( + controlLayers.entities.length === 0 && + controlLayers.isHidden === false && + inpaintMasks.entities.length === 0 && + inpaintMasks.isHidden === false && + rasterLayers.entities.length === 0 && + rasterLayers.isHidden === false && + regionalGuidance.entities.length === 0 && + regionalGuidance.isHidden === false + ); + } +); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 9d0f79b8a8..a4aee48086 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -1,3 +1,4 @@ +import { deepClone } from 'common/util/deepClone'; import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types'; import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers'; import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common'; @@ -438,51 +439,87 @@ export const isFluxKontextAspectRatioID = (v: unknown): v is z.infer; export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success; +const zAspectRatioConfig = z.object({ + id: zAspectRatioID, + value: z.number().gt(0), + isLocked: z.boolean(), +}); +type AspectRatioConfig = z.infer; + +export const DEFAULT_ASPECT_RATIO_CONFIG: AspectRatioConfig = { + id: '1:1', + value: 1, + isLocked: false, +}; + +const zBboxState = z.object({ + rect: z.object({ + x: z.number().int(), + y: z.number().int(), + width: zParameterImageDimension, + height: zParameterImageDimension, + }), + aspectRatio: zAspectRatioConfig, + scaledSize: z.object({ + width: zParameterImageDimension, + height: zParameterImageDimension, + }), + scaleMethod: zBoundingBoxScaleMethod, + modelBase: zMainModelBase, +}); const zCanvasState = z.object({ - _version: z.literal(3), - selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(), - bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(), - inpaintMasks: z.object({ - isHidden: z.boolean(), - entities: z.array(zCanvasInpaintMaskState), - }), - rasterLayers: z.object({ - isHidden: z.boolean(), - entities: z.array(zCanvasRasterLayerState), - }), - controlLayers: z.object({ - isHidden: z.boolean(), - entities: z.array(zCanvasControlLayerState), - }), - regionalGuidance: z.object({ - isHidden: z.boolean(), - entities: z.array(zCanvasRegionalGuidanceState), - }), - referenceImages: z.object({ - entities: z.array(zCanvasReferenceImageState), - }), - bbox: z.object({ - rect: z.object({ - x: z.number().int(), - y: z.number().int(), - width: zParameterImageDimension, - height: zParameterImageDimension, - }), - aspectRatio: z.object({ - id: zAspectRatioID, - value: z.number().gt(0), - isLocked: z.boolean(), - }), - scaledSize: z.object({ - width: zParameterImageDimension, - height: zParameterImageDimension, - }), - scaleMethod: zBoundingBoxScaleMethod, - modelBase: zMainModelBase, + _version: z.literal(3).default(3), + isSessionStarted: z.boolean().default(false), + selectedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null), + bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null), + inpaintMasks: z + .object({ + isHidden: z.boolean(), + entities: z.array(zCanvasInpaintMaskState), + }) + .default({ isHidden: false, entities: [] }), + rasterLayers: z + .object({ + isHidden: z.boolean(), + entities: z.array(zCanvasRasterLayerState), + }) + .default({ isHidden: false, entities: [] }), + controlLayers: z + .object({ + isHidden: z.boolean(), + entities: z.array(zCanvasControlLayerState), + }) + .default({ isHidden: false, entities: [] }), + regionalGuidance: z + .object({ + isHidden: z.boolean(), + entities: z.array(zCanvasRegionalGuidanceState), + }) + .default({ isHidden: false, entities: [] }), + referenceImages: z + .object({ + entities: z.array(zCanvasReferenceImageState), + }) + .default({ entities: [] }), + bbox: zBboxState.default({ + rect: { x: 0, y: 0, width: 512, height: 512 }, + aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG, + scaleMethod: 'auto', + scaledSize: { + width: 512, + height: 512, + }, + modelBase: 'sd-1', }), }); export type CanvasState = z.infer; +/** + * Gets a fresh canvas initial state with no references in memory to existing objects. + */ +const CANVAS_INITIAL_STATE = zCanvasState.parse({}); +export const getInitialState = () => deepClone(CANVAS_INITIAL_STATE); + export const zCanvasMetadata = z.object({ inpaintMasks: z.array(zCanvasInpaintMaskState), rasterLayers: z.array(zCanvasRasterLayerState), diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/store/dynamicPromptsSlice.ts b/invokeai/frontend/web/src/features/dynamicPrompts/store/dynamicPromptsSlice.ts index 9f95039c4f..741c415f38 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/store/dynamicPromptsSlice.ts +++ b/invokeai/frontend/web/src/features/dynamicPrompts/store/dynamicPromptsSlice.ts @@ -1,11 +1,12 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; +import { buildZodTypeGuard } from 'common/util/zodUtils'; import { z } from 'zod'; const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']); +export const isSeedBehaviour = buildZodTypeGuard(zSeedBehaviour); export type SeedBehaviour = z.infer; -export const isSeedBehaviour = (v: unknown): v is SeedBehaviour => zSeedBehaviour.safeParse(v).success; export interface DynamicPromptsState { _version: 1; diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index 1076100947..acc02785b9 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -1,5 +1,6 @@ import { NUMPY_RAND_MAX } from 'app/constants'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; +import { buildZodTypeGuard } from 'common/util/zodUtils'; import { zModelIdentifierField, zSchedulerField } from 'features/nodes/types/common'; import { z } from 'zod'; @@ -15,21 +16,12 @@ import { z } from 'zod'; * simply be the zod schema's safeParse function */ -/** - * Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type. - * @param schema The zod schema to create a type guard from. - * @returns A type guard function for the schema. - */ -export const buildTypeGuard = (schema: T) => { - return (val: unknown): val is z.infer => schema.safeParse(val).success; -}; - /** * Helper to create a zod schema and a type guard from it. * @param schema The zod schema to create a type guard from. * @returns A tuple containing the zod schema and the type guard function. */ -const buildParameter = (schema: T) => [schema, buildTypeGuard(schema)] as const; +export const buildParameter = (schema: T) => [schema, buildZodTypeGuard(schema)] as const; // #region Positive prompt export const [zParameterPositivePrompt, isParameterPositivePrompt] = buildParameter(z.string());