mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): use zod to define canvas state
By modeling canvas state as a zod schema vs a Typescript type, we get a runtime validator that can be used for metadata recall.
This commit is contained in:
committed by
Kent Keirsey
parent
4dc194510c
commit
95675c0545
@@ -24,8 +24,7 @@ import { getScaledBoundingBoxDimensions } from 'features/controlLayers/util/getS
|
||||
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { calculateNewSize } from 'features/parameters/components/Bbox/calculateNewSize';
|
||||
import { ASPECT_RATIO_MAP, initialAspectRatioState } from 'features/parameters/components/Bbox/constants';
|
||||
import type { AspectRatioID } from 'features/parameters/components/Bbox/types';
|
||||
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { merge, omit } from 'lodash-es';
|
||||
@@ -35,6 +34,7 @@ import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterM
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type {
|
||||
AspectRatioID,
|
||||
BoundingBoxScaleMethod,
|
||||
CanvasControlLayerState,
|
||||
CanvasEntityIdentifier,
|
||||
@@ -92,7 +92,11 @@ const getInitialState = (): CanvasState => {
|
||||
bbox: {
|
||||
rect: { x: 0, y: 0, width: 512, height: 512 },
|
||||
optimalDimension: 512,
|
||||
aspectRatio: deepClone(initialAspectRatioState),
|
||||
aspectRatio: {
|
||||
id: '1:1',
|
||||
value: 1,
|
||||
isLocked: false,
|
||||
},
|
||||
scaleMethod: 'auto',
|
||||
scaledSize: {
|
||||
width: 512,
|
||||
@@ -739,7 +743,7 @@ export const canvasSlice = createSlice({
|
||||
state.bbox.rect.width = width;
|
||||
state.bbox.rect.height = height;
|
||||
} else {
|
||||
state.bbox.aspectRatio = deepClone(initialAspectRatioState);
|
||||
state.bbox.aspectRatio = deepClone(initialState.bbox.aspectRatio);
|
||||
state.bbox.rect.width = state.bbox.optimalDimension;
|
||||
state.bbox.rect.height = state.bbox.optimalDimension;
|
||||
}
|
||||
|
||||
@@ -19,26 +19,66 @@ import type { Invocation } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
import { describe, test } from 'vitest';
|
||||
import type { z } from 'zod';
|
||||
|
||||
import type { CLIPVisionModelV2, ControlModeV2, IPMethodV2 } from './types';
|
||||
import type {
|
||||
CanvasEntityIdentifier,
|
||||
CLIPVisionModelV2,
|
||||
ControlModeV2,
|
||||
IPMethodV2,
|
||||
zCanvasEntityIdentifer,
|
||||
} from './types';
|
||||
|
||||
describe('Control Adapter Types', () => {
|
||||
test('ProcessorType', () => {
|
||||
test('FilterType', () => {
|
||||
// FilterType is a union of all filter types. FilterConfig is inferred from a zod union. zod does not support
|
||||
// extracting a specific field from a union, so FilterType is defined separately. To ensure that FilterType is
|
||||
// consistent with FilterConfig['type'], we compare the two types.
|
||||
assert<Equals<FilterConfig['type'], FilterType>>();
|
||||
});
|
||||
test('IP Adapter Method', () => {
|
||||
// This ensures the manually defined IPMethodV2 type is consistent with the type we get from the API.
|
||||
assert<Equals<NonNullable<Invocation<'ip_adapter'>['method']>, IPMethodV2>>();
|
||||
});
|
||||
test('CLIP Vision Model', () => {
|
||||
// This ensures the manually defined CLIPVisionModelV2 type is consistent with the type we get from the API.
|
||||
assert<Equals<NonNullable<Invocation<'ip_adapter'>['clip_vision_model']>, CLIPVisionModelV2>>();
|
||||
});
|
||||
test('Control Mode', () => {
|
||||
// This ensures the manually defined ControlModeV2 type is consistent with the type we get from the API.
|
||||
assert<Equals<NonNullable<Invocation<'controlnet'>['control_mode']>, ControlModeV2>>();
|
||||
});
|
||||
test('DepthAnything Model Size', () => {
|
||||
// This ensures the manually defined DepthAnythingModelSize type is consistent with the type we get from the API.
|
||||
assert<Equals<NonNullable<Invocation<'depth_anything_depth_estimation'>['model_size']>, DepthAnythingModelSize>>();
|
||||
});
|
||||
test('Processor Configs', () => {
|
||||
// Types derived from OpenAPI
|
||||
type _CannyEdgeDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
|
||||
>;
|
||||
type _ColorMapFilterConfig = Required<Pick<Invocation<'color_map'>, 'type' | 'tile_size'>>;
|
||||
type _ContentShuffleFilterConfig = Required<Pick<Invocation<'content_shuffle'>, 'type' | 'scale_factor'>>;
|
||||
type _DepthAnythingFilterConfig = Required<
|
||||
Pick<Invocation<'depth_anything_depth_estimation'>, 'type' | 'model_size'>
|
||||
>;
|
||||
type _HEDEdgeDetectionFilterConfig = Required<Pick<Invocation<'hed_edge_detection'>, 'type' | 'scribble'>>;
|
||||
type _LineartAnimeEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_anime_edge_detection'>, 'type'>>;
|
||||
type _LineartEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_edge_detection'>, 'type' | 'coarse'>>;
|
||||
type _MediaPipeFaceDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'mediapipe_face_detection'>, 'type' | 'max_faces' | 'min_confidence'>
|
||||
>;
|
||||
type _MLSDDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
|
||||
>;
|
||||
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
|
||||
type _DWOpenposeDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
|
||||
>;
|
||||
type _PiDiNetEdgeDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'pidi_edge_detection'>, 'type' | 'quantize_edges' | 'scribble'>
|
||||
>;
|
||||
|
||||
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
|
||||
// The types prefixed with `_` are types generated from OpenAPI, while the types without the prefix are manually modeled.
|
||||
assert<Equals<_CannyEdgeDetectionFilterConfig, CannyEdgeDetectionFilterConfig>>();
|
||||
@@ -54,28 +94,9 @@ describe('Control Adapter Types', () => {
|
||||
assert<Equals<_DWOpenposeDetectionFilterConfig, DWOpenposeDetectionFilterConfig>>();
|
||||
assert<Equals<_PiDiNetEdgeDetectionFilterConfig, PiDiNetEdgeDetectionFilterConfig>>();
|
||||
});
|
||||
test('CanvasEntityIdentifier', () => {
|
||||
// The generic type `CanvasEntityIdentifier` is defined manually, but it must be equal to the inferred type from
|
||||
// the zod schema.
|
||||
assert<Equals<CanvasEntityIdentifier, z.infer<typeof zCanvasEntityIdentifer>>>();
|
||||
});
|
||||
});
|
||||
|
||||
// Types derived from OpenAPI
|
||||
type _CannyEdgeDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
|
||||
>;
|
||||
type _ColorMapFilterConfig = Required<Pick<Invocation<'color_map'>, 'type' | 'tile_size'>>;
|
||||
type _ContentShuffleFilterConfig = Required<Pick<Invocation<'content_shuffle'>, 'type' | 'scale_factor'>>;
|
||||
type _DepthAnythingFilterConfig = Required<Pick<Invocation<'depth_anything_depth_estimation'>, 'type' | 'model_size'>>;
|
||||
type _HEDEdgeDetectionFilterConfig = Required<Pick<Invocation<'hed_edge_detection'>, 'type' | 'scribble'>>;
|
||||
type _LineartAnimeEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_anime_edge_detection'>, 'type'>>;
|
||||
type _LineartEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_edge_detection'>, 'type' | 'coarse'>>;
|
||||
type _MediaPipeFaceDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'mediapipe_face_detection'>, 'type' | 'max_faces' | 'min_confidence'>
|
||||
>;
|
||||
type _MLSDDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
|
||||
>;
|
||||
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
|
||||
type _DWOpenposeDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
|
||||
>;
|
||||
type _PiDiNetEdgeDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'pidi_edge_detection'>, 'type' | 'quantize_edges' | 'scribble'>
|
||||
>;
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { AspectRatioState } from 'features/parameters/components/Bbox/types';
|
||||
import type { ParameterHeight, ParameterLoRAModel, ParameterWidth } from 'features/parameters/types/parameterSchemas';
|
||||
import { zParameterNegativePrompt, zParameterPositivePrompt } from 'features/parameters/types/parameterSchemas';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
zParameterImageDimension,
|
||||
zParameterNegativePrompt,
|
||||
zParameterPositivePrompt,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
@@ -217,20 +220,36 @@ export type BoundingBoxScaleMethod = z.infer<typeof zBoundingBoxScaleMethod>;
|
||||
export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMethod =>
|
||||
zBoundingBoxScaleMethod.safeParse(v).success;
|
||||
|
||||
export type CanvasEntityState =
|
||||
| CanvasRasterLayerState
|
||||
| CanvasControlLayerState
|
||||
| CanvasRegionalGuidanceState
|
||||
| CanvasInpaintMaskState
|
||||
| CanvasReferenceImageState;
|
||||
const zCanvasEntityState = z.discriminatedUnion('type', [
|
||||
zCanvasRasterLayerState,
|
||||
zCanvasControlLayerState,
|
||||
zCanvasRegionalGuidanceState,
|
||||
zCanvasInpaintMaskState,
|
||||
zCanvasReferenceImageState,
|
||||
]);
|
||||
export type CanvasEntityState = z.infer<typeof zCanvasEntityState>;
|
||||
|
||||
export type CanvasRenderableEntityState =
|
||||
| CanvasRasterLayerState
|
||||
| CanvasControlLayerState
|
||||
| CanvasRegionalGuidanceState
|
||||
| CanvasInpaintMaskState;
|
||||
const zCanvasRenderableEntityState = z.discriminatedUnion('type', [
|
||||
zCanvasRasterLayerState,
|
||||
zCanvasControlLayerState,
|
||||
zCanvasRegionalGuidanceState,
|
||||
zCanvasInpaintMaskState,
|
||||
]);
|
||||
export type CanvasRenderableEntityState = z.infer<typeof zCanvasRenderableEntityState>;
|
||||
|
||||
export type CanvasEntityType = CanvasEntityState['type'];
|
||||
const zCanvasEntityType = z.union([
|
||||
zCanvasRasterLayerState.shape.type,
|
||||
zCanvasControlLayerState.shape.type,
|
||||
zCanvasRegionalGuidanceState.shape.type,
|
||||
zCanvasInpaintMaskState.shape.type,
|
||||
zCanvasReferenceImageState.shape.type,
|
||||
]);
|
||||
export type CanvasEntityType = z.infer<typeof zCanvasEntityType>;
|
||||
|
||||
export const zCanvasEntityIdentifer = z.object({
|
||||
id: zId,
|
||||
type: zCanvasEntityType,
|
||||
});
|
||||
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
|
||||
|
||||
export type LoRA = {
|
||||
@@ -246,45 +265,55 @@ export type StagingAreaImage = {
|
||||
offsetY: number;
|
||||
};
|
||||
|
||||
export type CanvasState = {
|
||||
_version: 3;
|
||||
selectedEntityIdentifier: CanvasEntityIdentifier | null;
|
||||
bookmarkedEntityIdentifier: CanvasEntityIdentifier | null;
|
||||
inpaintMasks: {
|
||||
isHidden: boolean;
|
||||
entities: CanvasInpaintMaskState[];
|
||||
};
|
||||
rasterLayers: {
|
||||
isHidden: boolean;
|
||||
entities: CanvasRasterLayerState[];
|
||||
};
|
||||
controlLayers: {
|
||||
isHidden: boolean;
|
||||
entities: CanvasControlLayerState[];
|
||||
};
|
||||
regionalGuidance: {
|
||||
isHidden: boolean;
|
||||
entities: CanvasRegionalGuidanceState[];
|
||||
};
|
||||
referenceImages: {
|
||||
entities: CanvasReferenceImageState[];
|
||||
};
|
||||
bbox: {
|
||||
rect: {
|
||||
x: number;
|
||||
y: number;
|
||||
width: ParameterWidth;
|
||||
height: ParameterHeight;
|
||||
};
|
||||
aspectRatio: AspectRatioState;
|
||||
scaledSize: {
|
||||
width: ParameterWidth;
|
||||
height: ParameterHeight;
|
||||
};
|
||||
scaleMethod: BoundingBoxScaleMethod;
|
||||
optimalDimension: number;
|
||||
};
|
||||
};
|
||||
const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
|
||||
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
|
||||
export const isAspectRatioID = (v: string): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
|
||||
|
||||
const zCanvasState = z.object({
|
||||
_version: z.literal(3),
|
||||
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
|
||||
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
|
||||
inpaintMasks: z.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasInpaintMaskState),
|
||||
}),
|
||||
rasterLayers: z.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasRasterLayerState),
|
||||
}),
|
||||
controlLayers: z.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasControlLayerState),
|
||||
}),
|
||||
regionalGuidance: z.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasRegionalGuidanceState),
|
||||
}),
|
||||
referenceImages: z.object({
|
||||
entities: z.array(zCanvasReferenceImageState),
|
||||
}),
|
||||
bbox: z.object({
|
||||
rect: z.object({
|
||||
x: z.number().int(),
|
||||
y: z.number().int(),
|
||||
width: zParameterImageDimension,
|
||||
height: zParameterImageDimension,
|
||||
}),
|
||||
aspectRatio: z.object({
|
||||
id: zAspectRatioID,
|
||||
value: z.number().gt(0),
|
||||
isLocked: z.boolean(),
|
||||
}),
|
||||
scaledSize: z.object({
|
||||
width: zParameterImageDimension,
|
||||
height: zParameterImageDimension,
|
||||
}),
|
||||
scaleMethod: zBoundingBoxScaleMethod,
|
||||
optimalDimension: z.number().int().positive(),
|
||||
}),
|
||||
});
|
||||
|
||||
export type CanvasState = z.infer<typeof zCanvasState>;
|
||||
|
||||
export type StageAttrs = {
|
||||
x: Coordinate['x'];
|
||||
|
||||
@@ -5,8 +5,8 @@ import type { SingleValue } from 'chakra-react-select';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { bboxAspectRatioIdChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectAspectRatioID } from 'features/controlLayers/store/selectors';
|
||||
import { isAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import { ASPECT_RATIO_OPTIONS } from 'features/parameters/components/Bbox/constants';
|
||||
import { isAspectRatioID } from 'features/parameters/components/Bbox/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
||||
|
||||
import type { AspectRatioID, AspectRatioState } from './types';
|
||||
import type { AspectRatioID } from 'features/controlLayers/store/types';
|
||||
|
||||
export const ASPECT_RATIO_OPTIONS: ComboboxOption[] = [
|
||||
{ label: 'Free' as const, value: 'Free' },
|
||||
@@ -22,9 +21,3 @@ export const ASPECT_RATIO_MAP: Record<Exclude<AspectRatioID, 'Free'>, { ratio: n
|
||||
'2:3': { ratio: 2 / 3, inverseID: '3:2' },
|
||||
'9:16': { ratio: 9 / 16, inverseID: '16:9' },
|
||||
};
|
||||
|
||||
export const initialAspectRatioState: AspectRatioState = {
|
||||
id: '1:1',
|
||||
value: 1,
|
||||
isLocked: false,
|
||||
};
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
|
||||
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
|
||||
export const isAspectRatioID = (v: string): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
|
||||
|
||||
export type AspectRatioState = {
|
||||
id: AspectRatioID;
|
||||
value: number;
|
||||
isLocked: boolean;
|
||||
};
|
||||
@@ -82,18 +82,19 @@ export const isParameterSeed = (val: unknown): val is ParameterSeed => zParamete
|
||||
// #endregion
|
||||
|
||||
// #region Width
|
||||
const zParameterWidth = z
|
||||
export const zParameterImageDimension = z
|
||||
.number()
|
||||
.min(64)
|
||||
.transform((val) => roundToMultiple(val, 8));
|
||||
export type ParameterWidth = z.infer<typeof zParameterWidth>;
|
||||
export const isParameterWidth = (val: unknown): val is ParameterWidth => zParameterWidth.safeParse(val).success;
|
||||
export type ParameterWidth = z.infer<typeof zParameterImageDimension>;
|
||||
export const isParameterWidth = (val: unknown): val is ParameterWidth =>
|
||||
zParameterImageDimension.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Height
|
||||
const zParameterHeight = zParameterWidth;
|
||||
export type ParameterHeight = z.infer<typeof zParameterHeight>;
|
||||
export const isParameterHeight = (val: unknown): val is ParameterHeight => zParameterHeight.safeParse(val).success;
|
||||
export type ParameterHeight = z.infer<typeof zParameterImageDimension>;
|
||||
export const isParameterHeight = (val: unknown): val is ParameterHeight =>
|
||||
zParameterImageDimension.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Model
|
||||
|
||||
Reference in New Issue
Block a user