mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor(ui): just manually validate async stuff
This commit is contained in:
@@ -1,120 +0,0 @@
|
|||||||
import type { Equals } from 'tsafe';
|
|
||||||
import { assert } from 'tsafe';
|
|
||||||
import { beforeEach, describe, expect, it } from 'vitest';
|
|
||||||
import { z, ZodError } from 'zod';
|
|
||||||
|
|
||||||
import {
|
|
||||||
clearSchemaReplacements,
|
|
||||||
registerSchemaReplacement,
|
|
||||||
replaceWithServerValidatedSchemas,
|
|
||||||
} from './replaceWithServerValidatedSchemas';
|
|
||||||
|
|
||||||
describe('replaceWithServerValidatedSchemas', () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
clearSchemaReplacements();
|
|
||||||
});
|
|
||||||
|
|
||||||
const zFoo = z.literal('foo');
|
|
||||||
|
|
||||||
const zFooAsyncOK = zFoo.refine(() => {
|
|
||||||
return Promise.resolve(true);
|
|
||||||
});
|
|
||||||
|
|
||||||
const zFooAsyncFAIL = zFoo.refine(() => {
|
|
||||||
return Promise.resolve(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should should not alter the type of the schema', () => {
|
|
||||||
const zTest = z.object({
|
|
||||||
foo: zFoo,
|
|
||||||
});
|
|
||||||
registerSchemaReplacement(zFoo, zFooAsyncOK);
|
|
||||||
const _serverValidatedSchema = replaceWithServerValidatedSchemas(zTest);
|
|
||||||
|
|
||||||
assert<Equals<z.infer<typeof _serverValidatedSchema>, z.infer<typeof zTest>>>();
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should pass validation when the replaced async validator passes', async () => {
|
|
||||||
const zTest = z.object({
|
|
||||||
foo: zFoo,
|
|
||||||
});
|
|
||||||
registerSchemaReplacement(zFoo, zFooAsyncOK);
|
|
||||||
const serverValidatedSchema = replaceWithServerValidatedSchemas(zTest);
|
|
||||||
|
|
||||||
expect(() => serverValidatedSchema.parse({ foo: 'foo' })).toThrow(
|
|
||||||
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
|
|
||||||
);
|
|
||||||
await expect(serverValidatedSchema.parseAsync({ foo: 'foo' })).resolves.toEqual({ foo: 'foo' });
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should fail validation when the replaced async validator fails', async () => {
|
|
||||||
const zTest = z.object({
|
|
||||||
foo: zFoo,
|
|
||||||
});
|
|
||||||
registerSchemaReplacement(zFoo, zFooAsyncFAIL);
|
|
||||||
const serverValidatedSchema = replaceWithServerValidatedSchemas(zTest);
|
|
||||||
|
|
||||||
expect(() => serverValidatedSchema.parse({ foo: 'foo' })).toThrow(
|
|
||||||
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
|
|
||||||
);
|
|
||||||
await expect(serverValidatedSchema.parseAsync({ foo: 'foo' })).rejects.toThrow(ZodError);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle deeply-nested objects', async () => {
|
|
||||||
const zNested = z.object({
|
|
||||||
nested: z.object({
|
|
||||||
foo: zFoo,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
registerSchemaReplacement(zFoo, zFooAsyncOK);
|
|
||||||
const serverValidatedSchema = replaceWithServerValidatedSchemas(zNested);
|
|
||||||
|
|
||||||
expect(() => serverValidatedSchema.parse({ nested: { foo: 'foo' } })).toThrow(
|
|
||||||
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
|
|
||||||
);
|
|
||||||
|
|
||||||
await expect(serverValidatedSchema.parseAsync({ nested: { foo: 'foo' } })).resolves.toEqual({
|
|
||||||
nested: { foo: 'foo' },
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle arrays', async () => {
|
|
||||||
const zArray = z.array(zFoo);
|
|
||||||
|
|
||||||
registerSchemaReplacement(zFoo, zFooAsyncOK);
|
|
||||||
const serverValidatedSchema = replaceWithServerValidatedSchemas(zArray);
|
|
||||||
|
|
||||||
expect(() => serverValidatedSchema.parse(['foo', 'foo'])).toThrow(
|
|
||||||
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
|
|
||||||
);
|
|
||||||
|
|
||||||
await expect(serverValidatedSchema.parseAsync(['foo', 'foo'])).resolves.toEqual(['foo', 'foo']);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle sets', async () => {
|
|
||||||
const zSet = z.set(zFoo);
|
|
||||||
|
|
||||||
registerSchemaReplacement(zFoo, zFooAsyncOK);
|
|
||||||
const serverValidatedSchema = replaceWithServerValidatedSchemas(zSet);
|
|
||||||
|
|
||||||
expect(() => serverValidatedSchema.parse(new Set(['foo', 'foo']))).toThrow(
|
|
||||||
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
|
|
||||||
);
|
|
||||||
|
|
||||||
await expect(serverValidatedSchema.parseAsync(new Set(['foo', 'foo']))).resolves.toEqual(new Set(['foo']));
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle records', async () => {
|
|
||||||
const zRecord = z.record(z.string(), zFoo);
|
|
||||||
|
|
||||||
registerSchemaReplacement(zFoo, zFooAsyncOK);
|
|
||||||
const serverValidatedSchema = replaceWithServerValidatedSchemas(zRecord);
|
|
||||||
|
|
||||||
expect(() => serverValidatedSchema.parse({ a: 'foo', b: 'foo' })).toThrow(
|
|
||||||
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
|
|
||||||
);
|
|
||||||
|
|
||||||
await expect(serverValidatedSchema.parseAsync({ a: 'foo', b: 'foo' })).resolves.toEqual({ a: 'foo', b: 'foo' });
|
|
||||||
});
|
|
||||||
});
|
|
||||||
@@ -1,240 +0,0 @@
|
|||||||
import { z } from 'zod';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Map of non-server-validated schemas to their server-validated counterparts.
|
|
||||||
* Add entries here for any schemas that need to be replaced.
|
|
||||||
*/
|
|
||||||
const schemaReplacementMap = new Map<z.ZodType, z.ZodType>();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Register a schema replacement mapping.
|
|
||||||
* @param originalSchema The non-server-validated schema
|
|
||||||
* @param serverValidatedSchema The server-validated replacement schema
|
|
||||||
*/
|
|
||||||
export function registerSchemaReplacement<T extends z.ZodType>(originalSchema: T, serverValidatedSchema: T): void {
|
|
||||||
schemaReplacementMap.set(originalSchema, serverValidatedSchema);
|
|
||||||
}
|
|
||||||
|
|
||||||
export function clearSchemaReplacements(): void {
|
|
||||||
schemaReplacementMap.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Recursively replaces non-server-validated schemas with server-validated ones.
|
|
||||||
* Handles objects, arrays, unions, intersections, and other composite types.
|
|
||||||
*
|
|
||||||
* @param schema The schema to transform
|
|
||||||
* @returns A new schema with server-validated replacements
|
|
||||||
*/
|
|
||||||
export function replaceWithServerValidatedSchemas<T extends z.ZodType>(schema: T): T {
|
|
||||||
// Check if this schema has a direct replacement
|
|
||||||
const replacement = schemaReplacementMap.get(schema);
|
|
||||||
if (replacement) {
|
|
||||||
return replacement as T;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Access the internal definition
|
|
||||||
const def = schema._zod.def;
|
|
||||||
const type = def.type;
|
|
||||||
|
|
||||||
// Handle different schema types
|
|
||||||
if (type === 'object') {
|
|
||||||
// For objects, recursively transform the shape
|
|
||||||
const shape = (def as any).shape;
|
|
||||||
if (!shape) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newShape: Record<string, z.ZodType> = {};
|
|
||||||
for (const key in shape) {
|
|
||||||
newShape[key] = replaceWithServerValidatedSchemas(shape[key]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new object with the transformed shape
|
|
||||||
const newSchema = z.object(newShape);
|
|
||||||
|
|
||||||
// Preserve the original object configuration (strict/strip/passthrough)
|
|
||||||
const config = (def as any).config;
|
|
||||||
if (config?.type === 'strict') {
|
|
||||||
return newSchema.strict();
|
|
||||||
} else if (config?.type === 'loose') {
|
|
||||||
return newSchema.passthrough();
|
|
||||||
}
|
|
||||||
|
|
||||||
return newSchema;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'array') {
|
|
||||||
// For arrays, transform the element type
|
|
||||||
const element = (def as any).element;
|
|
||||||
if (!element) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newElement = replaceWithServerValidatedSchemas(element);
|
|
||||||
return z.array(newElement);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'union') {
|
|
||||||
// For unions, transform all options
|
|
||||||
const options = (def as any).options;
|
|
||||||
if (!options || !Array.isArray(options)) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newOptions = options.map((opt) => replaceWithServerValidatedSchemas(opt));
|
|
||||||
return z.union(newOptions as [z.ZodType, z.ZodType, ...z.ZodType[]]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'intersection') {
|
|
||||||
// For intersections, transform both sides
|
|
||||||
const left = (def as any).left;
|
|
||||||
const right = (def as any).right;
|
|
||||||
if (!left || !right) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newLeft = replaceWithServerValidatedSchemas(left);
|
|
||||||
const newRight = replaceWithServerValidatedSchemas(right);
|
|
||||||
return z.intersection(newLeft, newRight);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'optional') {
|
|
||||||
// For optional, transform the inner type
|
|
||||||
const inner = (def as any).inner;
|
|
||||||
if (!inner) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newInner = replaceWithServerValidatedSchemas(inner);
|
|
||||||
return newInner.optional();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'nullable') {
|
|
||||||
// For nullable, transform the inner type
|
|
||||||
const inner = (def as any).inner;
|
|
||||||
if (!inner) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newInner = replaceWithServerValidatedSchemas(inner);
|
|
||||||
return newInner.nullable();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'default') {
|
|
||||||
// For default, transform the inner type and preserve default value
|
|
||||||
const inner = (def as any).inner;
|
|
||||||
const defaultValue = (def as any).defaultValue;
|
|
||||||
if (!inner) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newInner = replaceWithServerValidatedSchemas(inner);
|
|
||||||
return newInner.default(defaultValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'catch') {
|
|
||||||
// For catch, transform the inner type and preserve catch value
|
|
||||||
const inner = (def as any).inner;
|
|
||||||
const catchValue = (def as any).catchValue;
|
|
||||||
if (!inner) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newInner = replaceWithServerValidatedSchemas(inner);
|
|
||||||
return newInner.catch(catchValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'readonly') {
|
|
||||||
// For readonly, transform the inner type
|
|
||||||
const inner = (def as any).inner;
|
|
||||||
if (!inner) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newInner = replaceWithServerValidatedSchemas(inner);
|
|
||||||
return newInner.readonly();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'promise') {
|
|
||||||
// For promise, transform the inner type
|
|
||||||
const inner = (def as any).inner;
|
|
||||||
if (!inner) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newInner = replaceWithServerValidatedSchemas(inner);
|
|
||||||
return z.promise(newInner);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'lazy') {
|
|
||||||
// For lazy schemas, we need to wrap the getter function
|
|
||||||
const getter = (def as any).getter;
|
|
||||||
if (!getter) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
return z.lazy(() => replaceWithServerValidatedSchemas(getter()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'record') {
|
|
||||||
// For records, transform the value type
|
|
||||||
const valueType = (def as any).valueType;
|
|
||||||
const keyType = (def as any).keyType;
|
|
||||||
if (!valueType) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newValueType = replaceWithServerValidatedSchemas(valueType);
|
|
||||||
|
|
||||||
if (keyType) {
|
|
||||||
return z.record(keyType, newValueType);
|
|
||||||
}
|
|
||||||
return z.record(newValueType);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'map') {
|
|
||||||
// For maps, transform key and value types
|
|
||||||
const keyType = (def as any).keyType;
|
|
||||||
const valueType = (def as any).valueType;
|
|
||||||
if (!keyType || !valueType) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newKeyType = replaceWithServerValidatedSchemas(keyType);
|
|
||||||
const newValueType = replaceWithServerValidatedSchemas(valueType);
|
|
||||||
return z.map(newKeyType, newValueType);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'set') {
|
|
||||||
// For sets, transform the value type
|
|
||||||
const valueType = (def as any).valueType;
|
|
||||||
if (!valueType) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newValueType = replaceWithServerValidatedSchemas(valueType);
|
|
||||||
return z.set(newValueType);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'tuple') {
|
|
||||||
// For tuples, transform each item
|
|
||||||
const items = (def as any).items;
|
|
||||||
if (!items || !Array.isArray(items)) {
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newItems = items.map((item) => replaceWithServerValidatedSchemas(item));
|
|
||||||
return z.tuple(newItems as [z.ZodType, ...z.ZodType[]]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'transform' || type === 'pipe') {
|
|
||||||
// For transforms and pipes, we need to handle carefully
|
|
||||||
// In v4, these might have different internal structure
|
|
||||||
// For now, return as-is since transforming these could break functionality
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
|
|
||||||
// For primitive types and any unhandled types, return as-is
|
|
||||||
return schema;
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import { deepClone } from 'common/util/deepClone';
|
import { deepClone } from 'common/util/deepClone';
|
||||||
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
|
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
|
||||||
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
|
|
||||||
import type { ProgressImage } from 'features/nodes/types/common';
|
import type { ProgressImage } from 'features/nodes/types/common';
|
||||||
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
|
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import {
|
import {
|
||||||
@@ -28,32 +27,17 @@ import {
|
|||||||
zParameterT5EncoderModel,
|
zParameterT5EncoderModel,
|
||||||
zParameterVAEModel,
|
zParameterVAEModel,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import { getImageDTOSafe } from 'services/api/endpoints/images';
|
|
||||||
import type { JsonObject } from 'type-fest';
|
import type { JsonObject } from 'type-fest';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
|
|
||||||
const zId = z.string().min(1);
|
const zId = z.string().min(1);
|
||||||
const zName = z.string().min(1).nullable();
|
const zName = z.string().min(1).nullable();
|
||||||
|
|
||||||
export const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => {
|
|
||||||
try {
|
|
||||||
await fetchModelConfigByIdentifier(modelIdentifier);
|
|
||||||
return true;
|
|
||||||
} catch {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
export const zImageWithDims = z.object({
|
export const zImageWithDims = z.object({
|
||||||
image_name: z.string(),
|
image_name: z.string(),
|
||||||
width: z.number().int().positive(),
|
width: z.number().int().positive(),
|
||||||
height: z.number().int().positive(),
|
height: z.number().int().positive(),
|
||||||
});
|
});
|
||||||
export const zServerValidatedImageWithDims = zImageWithDims.refine(async (v) => {
|
|
||||||
const { image_name } = v;
|
|
||||||
const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true });
|
|
||||||
return imageDTO !== null;
|
|
||||||
});
|
|
||||||
export type ImageWithDims = z.infer<typeof zImageWithDims>;
|
export type ImageWithDims = z.infer<typeof zImageWithDims>;
|
||||||
|
|
||||||
const zImageWithDimsDataURL = z.object({
|
const zImageWithDimsDataURL = z.object({
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ import { t } from 'i18next';
|
|||||||
import type { ComponentType } from 'react';
|
import type { ComponentType } from 'react';
|
||||||
import { useCallback, useEffect, useState } from 'react';
|
import { useCallback, useEffect, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
@@ -787,11 +788,55 @@ const LoRAs: CollectionMetadataHandler<LoRA[]> = {
|
|||||||
const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
|
const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
|
||||||
[SingleMetadataKey]: true,
|
[SingleMetadataKey]: true,
|
||||||
type: 'CanvasLayers',
|
type: 'CanvasLayers',
|
||||||
parse: async (metadata) => {
|
parse: async (metadata, store) => {
|
||||||
const raw = getProperty(metadata, 'canvas_v2_metadata');
|
const raw = getProperty(metadata, 'canvas_v2_metadata');
|
||||||
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
|
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
|
||||||
// the zImageWithDims schema.
|
// the zImageWithDims schema.
|
||||||
const parsed = await zCanvasMetadata.parseAsync(raw);
|
const parsed = await zCanvasMetadata.parseAsync(raw);
|
||||||
|
|
||||||
|
for (const entity of parsed.controlLayers) {
|
||||||
|
if (entity.controlAdapter.model) {
|
||||||
|
await throwIfModelDoesNotExist(entity.controlAdapter.model.key, store);
|
||||||
|
}
|
||||||
|
for (const object of entity.objects) {
|
||||||
|
if (object.type === 'image' && 'image_name' in object.image) {
|
||||||
|
await throwIfImageDoesNotExist(object.image.image_name, store);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const entity of parsed.inpaintMasks) {
|
||||||
|
for (const object of entity.objects) {
|
||||||
|
if (object.type === 'image' && 'image_name' in object.image) {
|
||||||
|
await throwIfImageDoesNotExist(object.image.image_name, store);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const entity of parsed.rasterLayers) {
|
||||||
|
for (const object of entity.objects) {
|
||||||
|
if (object.type === 'image' && 'image_name' in object.image) {
|
||||||
|
await throwIfImageDoesNotExist(object.image.image_name, store);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const entity of parsed.regionalGuidance) {
|
||||||
|
for (const object of entity.objects) {
|
||||||
|
if (object.type === 'image' && 'image_name' in object.image) {
|
||||||
|
await throwIfImageDoesNotExist(object.image.image_name, store);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const refImage of entity.referenceImages) {
|
||||||
|
if (refImage.config.image) {
|
||||||
|
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
|
||||||
|
}
|
||||||
|
if (refImage.config.model) {
|
||||||
|
await throwIfModelDoesNotExist(refImage.config.model.key, store);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return Promise.resolve(parsed);
|
return Promise.resolve(parsed);
|
||||||
},
|
},
|
||||||
recall: (value, store) => {
|
recall: (value, store) => {
|
||||||
@@ -824,27 +869,39 @@ const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
|
|||||||
const RefImages: CollectionMetadataHandler<RefImageState[]> = {
|
const RefImages: CollectionMetadataHandler<RefImageState[]> = {
|
||||||
[CollectionMetadataKey]: true,
|
[CollectionMetadataKey]: true,
|
||||||
type: 'RefImages',
|
type: 'RefImages',
|
||||||
parse: async (metadata) => {
|
parse: async (metadata, store) => {
|
||||||
|
let parsed: RefImageState[] | null = null;
|
||||||
try {
|
try {
|
||||||
// First attempt to parse from the v6 slot
|
// First attempt to parse from the v6 slot
|
||||||
const raw = getProperty(metadata, 'ref_images');
|
const raw = getProperty(metadata, 'ref_images');
|
||||||
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
|
parsed = z.array(zRefImageState).parse(raw);
|
||||||
// the zImageWithDims schema.
|
|
||||||
const parsed = await z.array(zRefImageState).parseAsync(raw);
|
|
||||||
return Promise.resolve(parsed);
|
|
||||||
} catch {
|
} catch {
|
||||||
// Fall back to extracting from canvas metadata]
|
// Fall back to extracting from canvas metadata]
|
||||||
const raw = getProperty(metadata, 'canvas_v2_metadata.referenceImages.entities');
|
const raw = getProperty(metadata, 'canvas_v2_metadata.referenceImages.entities');
|
||||||
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
|
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
|
||||||
// the zImageWithDims schema.
|
// the zImageWithDims schema.
|
||||||
const oldParsed = await z.array(zCanvasReferenceImageState_OLD).parseAsync(raw);
|
const oldParsed = await z.array(zCanvasReferenceImageState_OLD).parseAsync(raw);
|
||||||
const parsed: RefImageState[] = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({
|
parsed = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({
|
||||||
id,
|
id,
|
||||||
config: ipAdapter,
|
config: ipAdapter,
|
||||||
isEnabled,
|
isEnabled,
|
||||||
}));
|
}));
|
||||||
return parsed;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!parsed) {
|
||||||
|
throw new Error('No valid reference images found in metadata');
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const refImage of parsed) {
|
||||||
|
if (refImage.config.image) {
|
||||||
|
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
|
||||||
|
}
|
||||||
|
if (refImage.config.model) {
|
||||||
|
await throwIfModelDoesNotExist(refImage.config.model.key, store);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsed;
|
||||||
},
|
},
|
||||||
recall: (value, store) => {
|
recall: (value, store) => {
|
||||||
const entities = value.map((data) => ({ ...data, id: getPrefixedId('reference_image') }));
|
const entities = value.map((data) => ({ ...data, id: getPrefixedId('reference_image') }));
|
||||||
@@ -1241,3 +1298,19 @@ const isCompatibleWithMainModel = (candidate: ModelIdentifierField, store: AppSt
|
|||||||
}
|
}
|
||||||
return candidate.base === base;
|
return candidate.base === base;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const throwIfImageDoesNotExist = async (name: string, store: AppStore): Promise<void> => {
|
||||||
|
try {
|
||||||
|
await store.dispatch(imagesApi.endpoints.getImageDTO.initiate(name, { subscribe: false })).unwrap();
|
||||||
|
} catch {
|
||||||
|
throw new Error(`Image with name ${name} does not exist`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const throwIfModelDoesNotExist = async (key: string, store: AppStore): Promise<void> => {
|
||||||
|
try {
|
||||||
|
await store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false }));
|
||||||
|
} catch {
|
||||||
|
throw new Error(`Model with key ${key} does not exist`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user