mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): create canvas metadata zod schema
- Add async zod refiner to `zImageWithDims` which fetches the image as part of validation - Add `zServerValidatedModelIdentifierField`, a zod-refined version of `zModelIdentifierField` which fetches the model as part of validation - Add `zCanvasMetadata` zod schema, which contains only canvas entities - no bbox, and no `isHidden` flags
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
@@ -6,17 +7,33 @@ import {
|
||||
zParameterNegativePrompt,
|
||||
zParameterPositivePrompt,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { getImageDTO } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zId = z.string().min(1);
|
||||
const zName = z.string().min(1).nullable();
|
||||
|
||||
const zImageWithDims = z.object({
|
||||
image_name: z.string(),
|
||||
width: z.number().int().positive(),
|
||||
height: z.number().int().positive(),
|
||||
const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => {
|
||||
try {
|
||||
await fetchModelConfigByIdentifier(modelIdentifier);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
});
|
||||
|
||||
const zImageWithDims = z
|
||||
.object({
|
||||
image_name: z.string(),
|
||||
width: z.number().int().positive(),
|
||||
height: z.number().int().positive(),
|
||||
})
|
||||
.refine(async (v) => {
|
||||
const { image_name } = v;
|
||||
const imageDTO = await getImageDTO(image_name);
|
||||
return imageDTO !== null;
|
||||
});
|
||||
export type ImageWithDims = z.infer<typeof zImageWithDims>;
|
||||
|
||||
const zBeginEndStepPct = z
|
||||
@@ -116,9 +133,8 @@ const zCanvasImageState = z.object({
|
||||
image: zImageWithDims,
|
||||
});
|
||||
export type CanvasImageState = z.infer<typeof zCanvasImageState>;
|
||||
export const isCanvasImageState = (v: unknown): v is CanvasImageState => zCanvasImageState.safeParse(v).success;
|
||||
|
||||
const zCanvasObjectState = z.discriminatedUnion('type', [
|
||||
const zCanvasObjectState = z.union([
|
||||
zCanvasImageState,
|
||||
zCanvasBrushLineState,
|
||||
zCanvasEraserLineState,
|
||||
@@ -129,7 +145,7 @@ export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
|
||||
const zIPAdapterConfig = z.object({
|
||||
type: z.literal('ip_adapter'),
|
||||
image: zImageWithDims.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
beginEndStepPct: zBeginEndStepPct,
|
||||
method: zIPMethodV2,
|
||||
@@ -185,7 +201,7 @@ export type CanvasInpaintMaskState = z.infer<typeof zCanvasInpaintMaskState>;
|
||||
|
||||
const zControlNetConfig = z.object({
|
||||
type: z.literal('controlnet'),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
beginEndStepPct: zBeginEndStepPct,
|
||||
controlMode: zControlModeV2,
|
||||
@@ -194,7 +210,7 @@ export type ControlNetConfig = z.infer<typeof zControlNetConfig>;
|
||||
|
||||
const zT2IAdapterConfig = z.object({
|
||||
type: z.literal('t2i_adapter'),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
beginEndStepPct: zBeginEndStepPct,
|
||||
});
|
||||
@@ -312,9 +328,17 @@ const zCanvasState = z.object({
|
||||
optimalDimension: z.number().int().positive(),
|
||||
}),
|
||||
});
|
||||
|
||||
export type CanvasState = z.infer<typeof zCanvasState>;
|
||||
|
||||
export const zCanvasMetadata = z.object({
|
||||
inpaintMasks: z.array(zCanvasInpaintMaskState),
|
||||
rasterLayers: z.array(zCanvasRasterLayerState),
|
||||
controlLayers: z.array(zCanvasControlLayerState),
|
||||
regionalGuidance: z.array(zCanvasRegionalGuidanceState),
|
||||
referenceImages: z.array(zCanvasReferenceImageState),
|
||||
});
|
||||
export type CanvasMetadata = z.infer<typeof zCanvasMetadata>;
|
||||
|
||||
export type StageAttrs = {
|
||||
x: Coordinate['x'];
|
||||
y: Coordinate['y'];
|
||||
|
||||
@@ -76,17 +76,17 @@ const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type:
|
||||
* @returns A promise that resolves to the model config.
|
||||
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
|
||||
*/
|
||||
// export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise<AnyModelConfig> => {
|
||||
// try {
|
||||
// return await fetchModelConfig(identifier.key);
|
||||
// } catch {
|
||||
// try {
|
||||
// return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type);
|
||||
// } catch {
|
||||
// throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`);
|
||||
// }
|
||||
// }
|
||||
// };
|
||||
export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise<AnyModelConfig> => {
|
||||
try {
|
||||
return await fetchModelConfig(identifier.key);
|
||||
} catch {
|
||||
try {
|
||||
return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type);
|
||||
} catch {
|
||||
throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.
|
||||
|
||||
Reference in New Issue
Block a user