diff --git a/invokeai/frontend/web/src/features/controlLayers/store/replaceWithServerValidatedSchemas.test.ts b/invokeai/frontend/web/src/features/controlLayers/store/replaceWithServerValidatedSchemas.test.ts deleted file mode 100644 index cd5cc152fd..0000000000 --- a/invokeai/frontend/web/src/features/controlLayers/store/replaceWithServerValidatedSchemas.test.ts +++ /dev/null @@ -1,120 +0,0 @@ -import type { Equals } from 'tsafe'; -import { assert } from 'tsafe'; -import { beforeEach, describe, expect, it } from 'vitest'; -import { z, ZodError } from 'zod'; - -import { - clearSchemaReplacements, - registerSchemaReplacement, - replaceWithServerValidatedSchemas, -} from './replaceWithServerValidatedSchemas'; - -describe('replaceWithServerValidatedSchemas', () => { - beforeEach(() => { - clearSchemaReplacements(); - }); - - const zFoo = z.literal('foo'); - - const zFooAsyncOK = zFoo.refine(() => { - return Promise.resolve(true); - }); - - const zFooAsyncFAIL = zFoo.refine(() => { - return Promise.resolve(false); - }); - - it('should should not alter the type of the schema', () => { - const zTest = z.object({ - foo: zFoo, - }); - registerSchemaReplacement(zFoo, zFooAsyncOK); - const _serverValidatedSchema = replaceWithServerValidatedSchemas(zTest); - - assert, z.infer>>(); - }); - - it('should pass validation when the replaced async validator passes', async () => { - const zTest = z.object({ - foo: zFoo, - }); - registerSchemaReplacement(zFoo, zFooAsyncOK); - const serverValidatedSchema = replaceWithServerValidatedSchemas(zTest); - - expect(() => serverValidatedSchema.parse({ foo: 'foo' })).toThrow( - 'Encountered Promise during synchronous parse. Use .parseAsync() instead.' - ); - await expect(serverValidatedSchema.parseAsync({ foo: 'foo' })).resolves.toEqual({ foo: 'foo' }); - }); - - it('should fail validation when the replaced async validator fails', async () => { - const zTest = z.object({ - foo: zFoo, - }); - registerSchemaReplacement(zFoo, zFooAsyncFAIL); - const serverValidatedSchema = replaceWithServerValidatedSchemas(zTest); - - expect(() => serverValidatedSchema.parse({ foo: 'foo' })).toThrow( - 'Encountered Promise during synchronous parse. Use .parseAsync() instead.' - ); - await expect(serverValidatedSchema.parseAsync({ foo: 'foo' })).rejects.toThrow(ZodError); - }); - - it('should handle deeply-nested objects', async () => { - const zNested = z.object({ - nested: z.object({ - foo: zFoo, - }), - }); - - registerSchemaReplacement(zFoo, zFooAsyncOK); - const serverValidatedSchema = replaceWithServerValidatedSchemas(zNested); - - expect(() => serverValidatedSchema.parse({ nested: { foo: 'foo' } })).toThrow( - 'Encountered Promise during synchronous parse. Use .parseAsync() instead.' - ); - - await expect(serverValidatedSchema.parseAsync({ nested: { foo: 'foo' } })).resolves.toEqual({ - nested: { foo: 'foo' }, - }); - }); - - it('should handle arrays', async () => { - const zArray = z.array(zFoo); - - registerSchemaReplacement(zFoo, zFooAsyncOK); - const serverValidatedSchema = replaceWithServerValidatedSchemas(zArray); - - expect(() => serverValidatedSchema.parse(['foo', 'foo'])).toThrow( - 'Encountered Promise during synchronous parse. Use .parseAsync() instead.' - ); - - await expect(serverValidatedSchema.parseAsync(['foo', 'foo'])).resolves.toEqual(['foo', 'foo']); - }); - - it('should handle sets', async () => { - const zSet = z.set(zFoo); - - registerSchemaReplacement(zFoo, zFooAsyncOK); - const serverValidatedSchema = replaceWithServerValidatedSchemas(zSet); - - expect(() => serverValidatedSchema.parse(new Set(['foo', 'foo']))).toThrow( - 'Encountered Promise during synchronous parse. Use .parseAsync() instead.' - ); - - await expect(serverValidatedSchema.parseAsync(new Set(['foo', 'foo']))).resolves.toEqual(new Set(['foo'])); - }); - - it('should handle records', async () => { - const zRecord = z.record(z.string(), zFoo); - - registerSchemaReplacement(zFoo, zFooAsyncOK); - const serverValidatedSchema = replaceWithServerValidatedSchemas(zRecord); - - expect(() => serverValidatedSchema.parse({ a: 'foo', b: 'foo' })).toThrow( - 'Encountered Promise during synchronous parse. Use .parseAsync() instead.' - ); - - await expect(serverValidatedSchema.parseAsync({ a: 'foo', b: 'foo' })).resolves.toEqual({ a: 'foo', b: 'foo' }); - }); -}); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/replaceWithServerValidatedSchemas.ts b/invokeai/frontend/web/src/features/controlLayers/store/replaceWithServerValidatedSchemas.ts deleted file mode 100644 index 08f4f95e15..0000000000 --- a/invokeai/frontend/web/src/features/controlLayers/store/replaceWithServerValidatedSchemas.ts +++ /dev/null @@ -1,240 +0,0 @@ -import { z } from 'zod'; - -/** - * Map of non-server-validated schemas to their server-validated counterparts. - * Add entries here for any schemas that need to be replaced. - */ -const schemaReplacementMap = new Map(); - -/** - * Register a schema replacement mapping. - * @param originalSchema The non-server-validated schema - * @param serverValidatedSchema The server-validated replacement schema - */ -export function registerSchemaReplacement(originalSchema: T, serverValidatedSchema: T): void { - schemaReplacementMap.set(originalSchema, serverValidatedSchema); -} - -export function clearSchemaReplacements(): void { - schemaReplacementMap.clear(); -} - -/** - * Recursively replaces non-server-validated schemas with server-validated ones. - * Handles objects, arrays, unions, intersections, and other composite types. - * - * @param schema The schema to transform - * @returns A new schema with server-validated replacements - */ -export function replaceWithServerValidatedSchemas(schema: T): T { - // Check if this schema has a direct replacement - const replacement = schemaReplacementMap.get(schema); - if (replacement) { - return replacement as T; - } - - // Access the internal definition - const def = schema._zod.def; - const type = def.type; - - // Handle different schema types - if (type === 'object') { - // For objects, recursively transform the shape - const shape = (def as any).shape; - if (!shape) { - return schema; - } - - const newShape: Record = {}; - for (const key in shape) { - newShape[key] = replaceWithServerValidatedSchemas(shape[key]); - } - - // Create a new object with the transformed shape - const newSchema = z.object(newShape); - - // Preserve the original object configuration (strict/strip/passthrough) - const config = (def as any).config; - if (config?.type === 'strict') { - return newSchema.strict(); - } else if (config?.type === 'loose') { - return newSchema.passthrough(); - } - - return newSchema; - } - - if (type === 'array') { - // For arrays, transform the element type - const element = (def as any).element; - if (!element) { - return schema; - } - - const newElement = replaceWithServerValidatedSchemas(element); - return z.array(newElement); - } - - if (type === 'union') { - // For unions, transform all options - const options = (def as any).options; - if (!options || !Array.isArray(options)) { - return schema; - } - - const newOptions = options.map((opt) => replaceWithServerValidatedSchemas(opt)); - return z.union(newOptions as [z.ZodType, z.ZodType, ...z.ZodType[]]); - } - - if (type === 'intersection') { - // For intersections, transform both sides - const left = (def as any).left; - const right = (def as any).right; - if (!left || !right) { - return schema; - } - - const newLeft = replaceWithServerValidatedSchemas(left); - const newRight = replaceWithServerValidatedSchemas(right); - return z.intersection(newLeft, newRight); - } - - if (type === 'optional') { - // For optional, transform the inner type - const inner = (def as any).inner; - if (!inner) { - return schema; - } - - const newInner = replaceWithServerValidatedSchemas(inner); - return newInner.optional(); - } - - if (type === 'nullable') { - // For nullable, transform the inner type - const inner = (def as any).inner; - if (!inner) { - return schema; - } - - const newInner = replaceWithServerValidatedSchemas(inner); - return newInner.nullable(); - } - - if (type === 'default') { - // For default, transform the inner type and preserve default value - const inner = (def as any).inner; - const defaultValue = (def as any).defaultValue; - if (!inner) { - return schema; - } - - const newInner = replaceWithServerValidatedSchemas(inner); - return newInner.default(defaultValue); - } - - if (type === 'catch') { - // For catch, transform the inner type and preserve catch value - const inner = (def as any).inner; - const catchValue = (def as any).catchValue; - if (!inner) { - return schema; - } - - const newInner = replaceWithServerValidatedSchemas(inner); - return newInner.catch(catchValue); - } - - if (type === 'readonly') { - // For readonly, transform the inner type - const inner = (def as any).inner; - if (!inner) { - return schema; - } - - const newInner = replaceWithServerValidatedSchemas(inner); - return newInner.readonly(); - } - - if (type === 'promise') { - // For promise, transform the inner type - const inner = (def as any).inner; - if (!inner) { - return schema; - } - - const newInner = replaceWithServerValidatedSchemas(inner); - return z.promise(newInner); - } - - if (type === 'lazy') { - // For lazy schemas, we need to wrap the getter function - const getter = (def as any).getter; - if (!getter) { - return schema; - } - - return z.lazy(() => replaceWithServerValidatedSchemas(getter())); - } - - if (type === 'record') { - // For records, transform the value type - const valueType = (def as any).valueType; - const keyType = (def as any).keyType; - if (!valueType) { - return schema; - } - - const newValueType = replaceWithServerValidatedSchemas(valueType); - - if (keyType) { - return z.record(keyType, newValueType); - } - return z.record(newValueType); - } - - if (type === 'map') { - // For maps, transform key and value types - const keyType = (def as any).keyType; - const valueType = (def as any).valueType; - if (!keyType || !valueType) { - return schema; - } - - const newKeyType = replaceWithServerValidatedSchemas(keyType); - const newValueType = replaceWithServerValidatedSchemas(valueType); - return z.map(newKeyType, newValueType); - } - - if (type === 'set') { - // For sets, transform the value type - const valueType = (def as any).valueType; - if (!valueType) { - return schema; - } - - const newValueType = replaceWithServerValidatedSchemas(valueType); - return z.set(newValueType); - } - - if (type === 'tuple') { - // For tuples, transform each item - const items = (def as any).items; - if (!items || !Array.isArray(items)) { - return schema; - } - - const newItems = items.map((item) => replaceWithServerValidatedSchemas(item)); - return z.tuple(newItems as [z.ZodType, ...z.ZodType[]]); - } - - if (type === 'transform' || type === 'pipe') { - // For transforms and pipes, we need to handle carefully - // In v4, these might have different internal structure - // For now, return as-is since transforming these could break functionality - return schema; - } - - // For primitive types and any unhandled types, return as-is - return schema; -} diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index bfdc2666af..bb703382ba 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -1,6 +1,5 @@ import { deepClone } from 'common/util/deepClone'; import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types'; -import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers'; import type { ProgressImage } from 'features/nodes/types/common'; import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common'; import { @@ -28,32 +27,17 @@ import { zParameterT5EncoderModel, zParameterVAEModel, } from 'features/parameters/types/parameterSchemas'; -import { getImageDTOSafe } from 'services/api/endpoints/images'; import type { JsonObject } from 'type-fest'; import { z } from 'zod'; const zId = z.string().min(1); const zName = z.string().min(1).nullable(); -export const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => { - try { - await fetchModelConfigByIdentifier(modelIdentifier); - return true; - } catch { - return false; - } -}); - export const zImageWithDims = z.object({ image_name: z.string(), width: z.number().int().positive(), height: z.number().int().positive(), }); -export const zServerValidatedImageWithDims = zImageWithDims.refine(async (v) => { - const { image_name } = v; - const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true }); - return imageDTO !== null; -}); export type ImageWithDims = z.infer; const zImageWithDimsDataURL = z.object({ diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index 34e032473f..ca66be6e24 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -89,6 +89,7 @@ import { t } from 'i18next'; import type { ComponentType } from 'react'; import { useCallback, useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { imagesApi } from 'services/api/endpoints/images'; import { modelsApi } from 'services/api/endpoints/models'; import type { AnyModelConfig, ModelType } from 'services/api/types'; import { assert } from 'tsafe'; @@ -787,11 +788,55 @@ const LoRAs: CollectionMetadataHandler = { const CanvasLayers: SingleMetadataHandler = { [SingleMetadataKey]: true, type: 'CanvasLayers', - parse: async (metadata) => { + parse: async (metadata, store) => { const raw = getProperty(metadata, 'canvas_v2_metadata'); // This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in // the zImageWithDims schema. const parsed = await zCanvasMetadata.parseAsync(raw); + + for (const entity of parsed.controlLayers) { + if (entity.controlAdapter.model) { + await throwIfModelDoesNotExist(entity.controlAdapter.model.key, store); + } + for (const object of entity.objects) { + if (object.type === 'image' && 'image_name' in object.image) { + await throwIfImageDoesNotExist(object.image.image_name, store); + } + } + } + + for (const entity of parsed.inpaintMasks) { + for (const object of entity.objects) { + if (object.type === 'image' && 'image_name' in object.image) { + await throwIfImageDoesNotExist(object.image.image_name, store); + } + } + } + + for (const entity of parsed.rasterLayers) { + for (const object of entity.objects) { + if (object.type === 'image' && 'image_name' in object.image) { + await throwIfImageDoesNotExist(object.image.image_name, store); + } + } + } + + for (const entity of parsed.regionalGuidance) { + for (const object of entity.objects) { + if (object.type === 'image' && 'image_name' in object.image) { + await throwIfImageDoesNotExist(object.image.image_name, store); + } + } + for (const refImage of entity.referenceImages) { + if (refImage.config.image) { + await throwIfImageDoesNotExist(refImage.config.image.image_name, store); + } + if (refImage.config.model) { + await throwIfModelDoesNotExist(refImage.config.model.key, store); + } + } + } + return Promise.resolve(parsed); }, recall: (value, store) => { @@ -824,27 +869,39 @@ const CanvasLayers: SingleMetadataHandler = { const RefImages: CollectionMetadataHandler = { [CollectionMetadataKey]: true, type: 'RefImages', - parse: async (metadata) => { + parse: async (metadata, store) => { + let parsed: RefImageState[] | null = null; try { // First attempt to parse from the v6 slot const raw = getProperty(metadata, 'ref_images'); - // This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in - // the zImageWithDims schema. - const parsed = await z.array(zRefImageState).parseAsync(raw); - return Promise.resolve(parsed); + parsed = z.array(zRefImageState).parse(raw); } catch { // Fall back to extracting from canvas metadata] const raw = getProperty(metadata, 'canvas_v2_metadata.referenceImages.entities'); // This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in // the zImageWithDims schema. const oldParsed = await z.array(zCanvasReferenceImageState_OLD).parseAsync(raw); - const parsed: RefImageState[] = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({ + parsed = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({ id, config: ipAdapter, isEnabled, })); - return parsed; } + + if (!parsed) { + throw new Error('No valid reference images found in metadata'); + } + + for (const refImage of parsed) { + if (refImage.config.image) { + await throwIfImageDoesNotExist(refImage.config.image.image_name, store); + } + if (refImage.config.model) { + await throwIfModelDoesNotExist(refImage.config.model.key, store); + } + } + + return parsed; }, recall: (value, store) => { const entities = value.map((data) => ({ ...data, id: getPrefixedId('reference_image') })); @@ -1241,3 +1298,19 @@ const isCompatibleWithMainModel = (candidate: ModelIdentifierField, store: AppSt } return candidate.base === base; }; + +const throwIfImageDoesNotExist = async (name: string, store: AppStore): Promise => { + try { + await store.dispatch(imagesApi.endpoints.getImageDTO.initiate(name, { subscribe: false })).unwrap(); + } catch { + throw new Error(`Image with name ${name} does not exist`); + } +}; + +const throwIfModelDoesNotExist = async (key: string, store: AppStore): Promise => { + try { + await store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false })); + } catch { + throw new Error(`Model with key ${key} does not exist`); + } +};