feat(ui): use zod to define canvas state

By modeling canvas state as a zod schema instead of a TypeScript type, we get a runtime validator that can be used for metadata recall.
This commit is contained in:
psychedelicious
2024-09-18 16:40:41 +10:00
committed by Kent Keirsey
parent 4dc194510c
commit 95675c0545
7 changed files with 147 additions and 110 deletions

View File

@@ -24,8 +24,7 @@ import { getScaledBoundingBoxDimensions } from 'features/controlLayers/util/getS
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { calculateNewSize } from 'features/parameters/components/Bbox/calculateNewSize';
import { ASPECT_RATIO_MAP, initialAspectRatioState } from 'features/parameters/components/Bbox/constants';
import type { AspectRatioID } from 'features/parameters/components/Bbox/types';
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
import { merge, omit } from 'lodash-es';
@@ -35,6 +34,7 @@ import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterM
import { assert } from 'tsafe';
import type {
AspectRatioID,
BoundingBoxScaleMethod,
CanvasControlLayerState,
CanvasEntityIdentifier,
@@ -92,7 +92,11 @@ const getInitialState = (): CanvasState => {
bbox: {
rect: { x: 0, y: 0, width: 512, height: 512 },
optimalDimension: 512,
aspectRatio: deepClone(initialAspectRatioState),
aspectRatio: {
id: '1:1',
value: 1,
isLocked: false,
},
scaleMethod: 'auto',
scaledSize: {
width: 512,
@@ -739,7 +743,7 @@ export const canvasSlice = createSlice({
state.bbox.rect.width = width;
state.bbox.rect.height = height;
} else {
state.bbox.aspectRatio = deepClone(initialAspectRatioState);
state.bbox.aspectRatio = deepClone(initialState.bbox.aspectRatio);
state.bbox.rect.width = state.bbox.optimalDimension;
state.bbox.rect.height = state.bbox.optimalDimension;
}

View File

@@ -19,26 +19,66 @@ import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type { z } from 'zod';
import type { CLIPVisionModelV2, ControlModeV2, IPMethodV2 } from './types';
import type {
CanvasEntityIdentifier,
CLIPVisionModelV2,
ControlModeV2,
IPMethodV2,
zCanvasEntityIdentifer,
} from './types';
describe('Control Adapter Types', () => {
test('ProcessorType', () => {
test('FilterType', () => {
// FilterType is a union of all filter types. FilterConfig is inferred from a zod union. zod does not support
// extracting a specific field from a union, so FilterType is defined separately. To ensure that FilterType is
// consistent with FilterConfig['type'], we compare the two types.
assert<Equals<FilterConfig['type'], FilterType>>();
});
test('IP Adapter Method', () => {
// This ensures the manually defined IPMethodV2 type is consistent with the type we get from the API.
assert<Equals<NonNullable<Invocation<'ip_adapter'>['method']>, IPMethodV2>>();
});
test('CLIP Vision Model', () => {
// This ensures the manually defined CLIPVisionModelV2 type is consistent with the type we get from the API.
assert<Equals<NonNullable<Invocation<'ip_adapter'>['clip_vision_model']>, CLIPVisionModelV2>>();
});
test('Control Mode', () => {
// This ensures the manually defined ControlModeV2 type is consistent with the type we get from the API.
assert<Equals<NonNullable<Invocation<'controlnet'>['control_mode']>, ControlModeV2>>();
});
test('DepthAnything Model Size', () => {
// This ensures the manually defined DepthAnythingModelSize type is consistent with the type we get from the API.
assert<Equals<NonNullable<Invocation<'depth_anything_depth_estimation'>['model_size']>, DepthAnythingModelSize>>();
});
test('Processor Configs', () => {
// Types derived from OpenAPI
type _CannyEdgeDetectionFilterConfig = Required<
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
>;
type _ColorMapFilterConfig = Required<Pick<Invocation<'color_map'>, 'type' | 'tile_size'>>;
type _ContentShuffleFilterConfig = Required<Pick<Invocation<'content_shuffle'>, 'type' | 'scale_factor'>>;
type _DepthAnythingFilterConfig = Required<
Pick<Invocation<'depth_anything_depth_estimation'>, 'type' | 'model_size'>
>;
type _HEDEdgeDetectionFilterConfig = Required<Pick<Invocation<'hed_edge_detection'>, 'type' | 'scribble'>>;
type _LineartAnimeEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_anime_edge_detection'>, 'type'>>;
type _LineartEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_edge_detection'>, 'type' | 'coarse'>>;
type _MediaPipeFaceDetectionFilterConfig = Required<
Pick<Invocation<'mediapipe_face_detection'>, 'type' | 'max_faces' | 'min_confidence'>
>;
type _MLSDDetectionFilterConfig = Required<
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
>;
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
type _DWOpenposeDetectionFilterConfig = Required<
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
>;
type _PiDiNetEdgeDetectionFilterConfig = Required<
Pick<Invocation<'pidi_edge_detection'>, 'type' | 'quantize_edges' | 'scribble'>
>;
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
// The types prefixed with `_` are types generated from OpenAPI, while the types without the prefix are manually modeled.
assert<Equals<_CannyEdgeDetectionFilterConfig, CannyEdgeDetectionFilterConfig>>();
@@ -54,28 +94,9 @@ describe('Control Adapter Types', () => {
assert<Equals<_DWOpenposeDetectionFilterConfig, DWOpenposeDetectionFilterConfig>>();
assert<Equals<_PiDiNetEdgeDetectionFilterConfig, PiDiNetEdgeDetectionFilterConfig>>();
});
test('CanvasEntityIdentifier', () => {
// The generic type `CanvasEntityIdentifier` is defined manually, but it must be equal to the inferred type from
// the zod schema.
assert<Equals<CanvasEntityIdentifier, z.infer<typeof zCanvasEntityIdentifer>>>();
});
});
// Types derived from OpenAPI
type _CannyEdgeDetectionFilterConfig = Required<
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
>;
type _ColorMapFilterConfig = Required<Pick<Invocation<'color_map'>, 'type' | 'tile_size'>>;
type _ContentShuffleFilterConfig = Required<Pick<Invocation<'content_shuffle'>, 'type' | 'scale_factor'>>;
type _DepthAnythingFilterConfig = Required<Pick<Invocation<'depth_anything_depth_estimation'>, 'type' | 'model_size'>>;
type _HEDEdgeDetectionFilterConfig = Required<Pick<Invocation<'hed_edge_detection'>, 'type' | 'scribble'>>;
type _LineartAnimeEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_anime_edge_detection'>, 'type'>>;
type _LineartEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_edge_detection'>, 'type' | 'coarse'>>;
type _MediaPipeFaceDetectionFilterConfig = Required<
Pick<Invocation<'mediapipe_face_detection'>, 'type' | 'max_faces' | 'min_confidence'>
>;
type _MLSDDetectionFilterConfig = Required<
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
>;
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
type _DWOpenposeDetectionFilterConfig = Required<
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
>;
type _PiDiNetEdgeDetectionFilterConfig = Required<
Pick<Invocation<'pidi_edge_detection'>, 'type' | 'quantize_edges' | 'scribble'>
>;

View File

@@ -1,8 +1,11 @@
import type { SerializableObject } from 'common/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { AspectRatioState } from 'features/parameters/components/Bbox/types';
import type { ParameterHeight, ParameterLoRAModel, ParameterWidth } from 'features/parameters/types/parameterSchemas';
import { zParameterNegativePrompt, zParameterPositivePrompt } from 'features/parameters/types/parameterSchemas';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import {
zParameterImageDimension,
zParameterNegativePrompt,
zParameterPositivePrompt,
} from 'features/parameters/types/parameterSchemas';
import type { ImageDTO } from 'services/api/types';
import { z } from 'zod';
@@ -217,20 +220,36 @@ export type BoundingBoxScaleMethod = z.infer<typeof zBoundingBoxScaleMethod>;
export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMethod =>
zBoundingBoxScaleMethod.safeParse(v).success;
export type CanvasEntityState =
| CanvasRasterLayerState
| CanvasControlLayerState
| CanvasRegionalGuidanceState
| CanvasInpaintMaskState
| CanvasReferenceImageState;
const zCanvasEntityState = z.discriminatedUnion('type', [
zCanvasRasterLayerState,
zCanvasControlLayerState,
zCanvasRegionalGuidanceState,
zCanvasInpaintMaskState,
zCanvasReferenceImageState,
]);
export type CanvasEntityState = z.infer<typeof zCanvasEntityState>;
export type CanvasRenderableEntityState =
| CanvasRasterLayerState
| CanvasControlLayerState
| CanvasRegionalGuidanceState
| CanvasInpaintMaskState;
const zCanvasRenderableEntityState = z.discriminatedUnion('type', [
zCanvasRasterLayerState,
zCanvasControlLayerState,
zCanvasRegionalGuidanceState,
zCanvasInpaintMaskState,
]);
export type CanvasRenderableEntityState = z.infer<typeof zCanvasRenderableEntityState>;
export type CanvasEntityType = CanvasEntityState['type'];
const zCanvasEntityType = z.union([
zCanvasRasterLayerState.shape.type,
zCanvasControlLayerState.shape.type,
zCanvasRegionalGuidanceState.shape.type,
zCanvasInpaintMaskState.shape.type,
zCanvasReferenceImageState.shape.type,
]);
export type CanvasEntityType = z.infer<typeof zCanvasEntityType>;
export const zCanvasEntityIdentifer = z.object({
id: zId,
type: zCanvasEntityType,
});
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
export type LoRA = {
@@ -246,45 +265,55 @@ export type StagingAreaImage = {
offsetY: number;
};
export type CanvasState = {
_version: 3;
selectedEntityIdentifier: CanvasEntityIdentifier | null;
bookmarkedEntityIdentifier: CanvasEntityIdentifier | null;
inpaintMasks: {
isHidden: boolean;
entities: CanvasInpaintMaskState[];
};
rasterLayers: {
isHidden: boolean;
entities: CanvasRasterLayerState[];
};
controlLayers: {
isHidden: boolean;
entities: CanvasControlLayerState[];
};
regionalGuidance: {
isHidden: boolean;
entities: CanvasRegionalGuidanceState[];
};
referenceImages: {
entities: CanvasReferenceImageState[];
};
bbox: {
rect: {
x: number;
y: number;
width: ParameterWidth;
height: ParameterHeight;
};
aspectRatio: AspectRatioState;
scaledSize: {
width: ParameterWidth;
height: ParameterHeight;
};
scaleMethod: BoundingBoxScaleMethod;
optimalDimension: number;
};
};
const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: string): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
const zCanvasState = z.object({
_version: z.literal(3),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
inpaintMasks: z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasInpaintMaskState),
}),
rasterLayers: z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRasterLayerState),
}),
controlLayers: z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasControlLayerState),
}),
regionalGuidance: z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRegionalGuidanceState),
}),
referenceImages: z.object({
entities: z.array(zCanvasReferenceImageState),
}),
bbox: z.object({
rect: z.object({
x: z.number().int(),
y: z.number().int(),
width: zParameterImageDimension,
height: zParameterImageDimension,
}),
aspectRatio: z.object({
id: zAspectRatioID,
value: z.number().gt(0),
isLocked: z.boolean(),
}),
scaledSize: z.object({
width: zParameterImageDimension,
height: zParameterImageDimension,
}),
scaleMethod: zBoundingBoxScaleMethod,
optimalDimension: z.number().int().positive(),
}),
});
export type CanvasState = z.infer<typeof zCanvasState>;
export type StageAttrs = {
x: Coordinate['x'];

View File

@@ -5,8 +5,8 @@ import type { SingleValue } from 'chakra-react-select';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { bboxAspectRatioIdChanged } from 'features/controlLayers/store/canvasSlice';
import { selectAspectRatioID } from 'features/controlLayers/store/selectors';
import { isAspectRatioID } from 'features/controlLayers/store/types';
import { ASPECT_RATIO_OPTIONS } from 'features/parameters/components/Bbox/constants';
import { isAspectRatioID } from 'features/parameters/components/Bbox/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';

View File

@@ -1,6 +1,5 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import type { AspectRatioID, AspectRatioState } from './types';
import type { AspectRatioID } from 'features/controlLayers/store/types';
export const ASPECT_RATIO_OPTIONS: ComboboxOption[] = [
{ label: 'Free' as const, value: 'Free' },
@@ -22,9 +21,3 @@ export const ASPECT_RATIO_MAP: Record<Exclude<AspectRatioID, 'Free'>, { ratio: n
'2:3': { ratio: 2 / 3, inverseID: '3:2' },
'9:16': { ratio: 9 / 16, inverseID: '16:9' },
};
export const initialAspectRatioState: AspectRatioState = {
id: '1:1',
value: 1,
isLocked: false,
};

View File

@@ -1,11 +0,0 @@
import { z } from 'zod';
const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: string): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
export type AspectRatioState = {
id: AspectRatioID;
value: number;
isLocked: boolean;
};

View File

@@ -82,18 +82,19 @@ export const isParameterSeed = (val: unknown): val is ParameterSeed => zParamete
// #endregion
// #region Width
const zParameterWidth = z
export const zParameterImageDimension = z
.number()
.min(64)
.transform((val) => roundToMultiple(val, 8));
export type ParameterWidth = z.infer<typeof zParameterWidth>;
export const isParameterWidth = (val: unknown): val is ParameterWidth => zParameterWidth.safeParse(val).success;
export type ParameterWidth = z.infer<typeof zParameterImageDimension>;
export const isParameterWidth = (val: unknown): val is ParameterWidth =>
zParameterImageDimension.safeParse(val).success;
// #endregion
// #region Height
const zParameterHeight = zParameterWidth;
export type ParameterHeight = z.infer<typeof zParameterHeight>;
export const isParameterHeight = (val: unknown): val is ParameterHeight => zParameterHeight.safeParse(val).success;
export type ParameterHeight = z.infer<typeof zParameterImageDimension>;
export const isParameterHeight = (val: unknown): val is ParameterHeight =>
zParameterImageDimension.safeParse(val).success;
// #endregion
// #region Model