feat(ui): create canvas metadata zod schema

- Add async zod refiner to `zImageWithDims` which fetches the image as part of validation
- Add `zServerValidatedModelIdentifierField`, a zod-refined version of `zModelIdentifierField` which fetches the model as part of validation
- Add `zCanvasMetadata` zod schema, which contains only canvas entities - no bbox, and no `isHidden` flags
This commit is contained in:
psychedelicious
2024-09-19 17:55:13 +10:00
parent 61091ac2fe
commit a28db7d496
2 changed files with 45 additions and 21 deletions

View File

@@ -1,4 +1,5 @@
import type { SerializableObject } from 'common/types';
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import {
@@ -6,17 +7,33 @@ import {
zParameterNegativePrompt,
zParameterPositivePrompt,
} from 'features/parameters/types/parameterSchemas';
import { getImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { z } from 'zod';
const zId = z.string().min(1);
const zName = z.string().min(1).nullable();
const zImageWithDims = z.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => {
try {
await fetchModelConfigByIdentifier(modelIdentifier);
return true;
} catch {
return false;
}
});
const zImageWithDims = z
.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
})
.refine(async (v) => {
const { image_name } = v;
const imageDTO = await getImageDTO(image_name);
return imageDTO !== null;
});
export type ImageWithDims = z.infer<typeof zImageWithDims>;
const zBeginEndStepPct = z
@@ -116,9 +133,8 @@ const zCanvasImageState = z.object({
image: zImageWithDims,
});
export type CanvasImageState = z.infer<typeof zCanvasImageState>;
export const isCanvasImageState = (v: unknown): v is CanvasImageState => zCanvasImageState.safeParse(v).success;
const zCanvasObjectState = z.discriminatedUnion('type', [
const zCanvasObjectState = z.union([
zCanvasImageState,
zCanvasBrushLineState,
zCanvasEraserLineState,
@@ -129,7 +145,7 @@ export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
const zIPAdapterConfig = z.object({
type: z.literal('ip_adapter'),
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2,
@@ -185,7 +201,7 @@ export type CanvasInpaintMaskState = z.infer<typeof zCanvasInpaintMaskState>;
const zControlNetConfig = z.object({
type: z.literal('controlnet'),
model: zModelIdentifierField.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
controlMode: zControlModeV2,
@@ -194,7 +210,7 @@ export type ControlNetConfig = z.infer<typeof zControlNetConfig>;
const zT2IAdapterConfig = z.object({
type: z.literal('t2i_adapter'),
model: zModelIdentifierField.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
});
@@ -312,9 +328,17 @@ const zCanvasState = z.object({
optimalDimension: z.number().int().positive(),
}),
});
export type CanvasState = z.infer<typeof zCanvasState>;
export const zCanvasMetadata = z.object({
inpaintMasks: z.array(zCanvasInpaintMaskState),
rasterLayers: z.array(zCanvasRasterLayerState),
controlLayers: z.array(zCanvasControlLayerState),
regionalGuidance: z.array(zCanvasRegionalGuidanceState),
referenceImages: z.array(zCanvasReferenceImageState),
});
export type CanvasMetadata = z.infer<typeof zCanvasMetadata>;
export type StageAttrs = {
x: Coordinate['x'];
y: Coordinate['y'];

View File

@@ -76,17 +76,17 @@ const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type:
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
// export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise<AnyModelConfig> => {
// try {
// return await fetchModelConfig(identifier.key);
// } catch {
// try {
// return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type);
// } catch {
// throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`);
// }
// }
// };
export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise<AnyModelConfig> => {
try {
return await fetchModelConfig(identifier.key);
} catch {
try {
return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type);
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`);
}
}
};
/**
* Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.