diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 0d0e538ba4..0f2abc5f2a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -1,4 +1,5 @@ import type { SerializableObject } from 'common/types'; +import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; import { @@ -6,17 +7,33 @@ import { zParameterNegativePrompt, zParameterPositivePrompt, } from 'features/parameters/types/parameterSchemas'; +import { getImageDTO } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; import { z } from 'zod'; const zId = z.string().min(1); const zName = z.string().min(1).nullable(); -const zImageWithDims = z.object({ - image_name: z.string(), - width: z.number().int().positive(), - height: z.number().int().positive(), +const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => { + try { + await fetchModelConfigByIdentifier(modelIdentifier); + return true; + } catch { + return false; + } }); + +const zImageWithDims = z + .object({ + image_name: z.string(), + width: z.number().int().positive(), + height: z.number().int().positive(), + }) + .refine(async (v) => { + const { image_name } = v; + const imageDTO = await getImageDTO(image_name); + return imageDTO !== null; + }); export type ImageWithDims = z.infer; const zBeginEndStepPct = z @@ -116,9 +133,8 @@ const zCanvasImageState = z.object({ image: zImageWithDims, }); export type CanvasImageState = z.infer; -export const isCanvasImageState = (v: unknown): v is CanvasImageState => zCanvasImageState.safeParse(v).success; -const zCanvasObjectState = z.discriminatedUnion('type', [ +const zCanvasObjectState = z.union([ zCanvasImageState, zCanvasBrushLineState, zCanvasEraserLineState, @@ -129,7 +145,7 @@ export type CanvasObjectState = z.infer; const zIPAdapterConfig = z.object({ type: z.literal('ip_adapter'), image: zImageWithDims.nullable(), - model: zModelIdentifierField.nullable(), + model: zServerValidatedModelIdentifierField.nullable(), weight: z.number().gte(-1).lte(2), beginEndStepPct: zBeginEndStepPct, method: zIPMethodV2, @@ -185,7 +201,7 @@ export type CanvasInpaintMaskState = z.infer; const zControlNetConfig = z.object({ type: z.literal('controlnet'), - model: zModelIdentifierField.nullable(), + model: zServerValidatedModelIdentifierField.nullable(), weight: z.number().gte(-1).lte(2), beginEndStepPct: zBeginEndStepPct, controlMode: zControlModeV2, @@ -194,7 +210,7 @@ export type ControlNetConfig = z.infer; const zT2IAdapterConfig = z.object({ type: z.literal('t2i_adapter'), - model: zModelIdentifierField.nullable(), + model: zServerValidatedModelIdentifierField.nullable(), weight: z.number().gte(-1).lte(2), beginEndStepPct: zBeginEndStepPct, }); @@ -312,9 +328,17 @@ const zCanvasState = z.object({ optimalDimension: z.number().int().positive(), }), }); - export type CanvasState = z.infer; +export const zCanvasMetadata = z.object({ + inpaintMasks: z.array(zCanvasInpaintMaskState), + rasterLayers: z.array(zCanvasRasterLayerState), + controlLayers: z.array(zCanvasControlLayerState), + regionalGuidance: z.array(zCanvasRegionalGuidanceState), + referenceImages: z.array(zCanvasReferenceImageState), +}); +export type CanvasMetadata = z.infer; + export type StageAttrs = { x: Coordinate['x']; y: Coordinate['y']; diff --git a/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts b/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts index bdf6a3bd21..4bd2436c0b 100644 --- a/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts @@ -76,17 +76,17 @@ const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type: * @returns A promise that resolves to the model config. * @throws {ModelConfigNotFoundError} If the model config is unable to be fetched. */ -// export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise => { -// try { -// return await fetchModelConfig(identifier.key); -// } catch { -// try { -// return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type); -// } catch { -// throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`); -// } -// } -// }; +export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise => { + try { + return await fetchModelConfig(identifier.key); + } catch { + try { + return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type); + } catch { + throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`); + } + } +}; /** * Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.