refactor(ui): work around zod async validation issue

This commit is contained in:
psychedelicious
2025-07-25 15:59:15 +10:00
parent 82cdfd83e4
commit a8662953fc
3 changed files with 380 additions and 21 deletions

View File

@@ -0,0 +1,120 @@
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { beforeEach, describe, expect, it } from 'vitest';
import { z, ZodError } from 'zod';
import {
clearSchemaReplacements,
registerSchemaReplacement,
replaceWithServerValidatedSchemas,
} from './replaceWithServerValidatedSchemas';
describe('replaceWithServerValidatedSchemas', () => {
beforeEach(() => {
clearSchemaReplacements();
});
const zFoo = z.literal('foo');
const zFooAsyncOK = zFoo.refine(() => {
return Promise.resolve(true);
});
const zFooAsyncFAIL = zFoo.refine(() => {
return Promise.resolve(false);
});
it('should should not alter the type of the schema', () => {
const zTest = z.object({
foo: zFoo,
});
registerSchemaReplacement(zFoo, zFooAsyncOK);
const _serverValidatedSchema = replaceWithServerValidatedSchemas(zTest);
assert<Equals<z.infer<typeof _serverValidatedSchema>, z.infer<typeof zTest>>>();
});
it('should pass validation when the replaced async validator passes', async () => {
const zTest = z.object({
foo: zFoo,
});
registerSchemaReplacement(zFoo, zFooAsyncOK);
const serverValidatedSchema = replaceWithServerValidatedSchemas(zTest);
expect(() => serverValidatedSchema.parse({ foo: 'foo' })).toThrow(
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
);
await expect(serverValidatedSchema.parseAsync({ foo: 'foo' })).resolves.toEqual({ foo: 'foo' });
});
it('should fail validation when the replaced async validator fails', async () => {
const zTest = z.object({
foo: zFoo,
});
registerSchemaReplacement(zFoo, zFooAsyncFAIL);
const serverValidatedSchema = replaceWithServerValidatedSchemas(zTest);
expect(() => serverValidatedSchema.parse({ foo: 'foo' })).toThrow(
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
);
await expect(serverValidatedSchema.parseAsync({ foo: 'foo' })).rejects.toThrow(ZodError);
});
it('should handle deeply-nested objects', async () => {
const zNested = z.object({
nested: z.object({
foo: zFoo,
}),
});
registerSchemaReplacement(zFoo, zFooAsyncOK);
const serverValidatedSchema = replaceWithServerValidatedSchemas(zNested);
expect(() => serverValidatedSchema.parse({ nested: { foo: 'foo' } })).toThrow(
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
);
await expect(serverValidatedSchema.parseAsync({ nested: { foo: 'foo' } })).resolves.toEqual({
nested: { foo: 'foo' },
});
});
it('should handle arrays', async () => {
const zArray = z.array(zFoo);
registerSchemaReplacement(zFoo, zFooAsyncOK);
const serverValidatedSchema = replaceWithServerValidatedSchemas(zArray);
expect(() => serverValidatedSchema.parse(['foo', 'foo'])).toThrow(
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
);
await expect(serverValidatedSchema.parseAsync(['foo', 'foo'])).resolves.toEqual(['foo', 'foo']);
});
it('should handle sets', async () => {
const zSet = z.set(zFoo);
registerSchemaReplacement(zFoo, zFooAsyncOK);
const serverValidatedSchema = replaceWithServerValidatedSchemas(zSet);
expect(() => serverValidatedSchema.parse(new Set(['foo', 'foo']))).toThrow(
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
);
await expect(serverValidatedSchema.parseAsync(new Set(['foo', 'foo']))).resolves.toEqual(new Set(['foo']));
});
it('should handle records', async () => {
const zRecord = z.record(z.string(), zFoo);
registerSchemaReplacement(zFoo, zFooAsyncOK);
const serverValidatedSchema = replaceWithServerValidatedSchemas(zRecord);
expect(() => serverValidatedSchema.parse({ a: 'foo', b: 'foo' })).toThrow(
'Encountered Promise during synchronous parse. Use .parseAsync() instead.'
);
await expect(serverValidatedSchema.parseAsync({ a: 'foo', b: 'foo' })).resolves.toEqual({ a: 'foo', b: 'foo' });
});
});

View File

@@ -0,0 +1,240 @@
import { z } from 'zod';
/**
* Map of non-server-validated schemas to their server-validated counterparts.
* Add entries here for any schemas that need to be replaced.
*/
const schemaReplacementMap = new Map<z.ZodType, z.ZodType>();
/**
* Register a schema replacement mapping.
* @param originalSchema The non-server-validated schema
* @param serverValidatedSchema The server-validated replacement schema
*/
export function registerSchemaReplacement<T extends z.ZodType>(originalSchema: T, serverValidatedSchema: T): void {
schemaReplacementMap.set(originalSchema, serverValidatedSchema);
}
export function clearSchemaReplacements(): void {
schemaReplacementMap.clear();
}
/**
* Recursively replaces non-server-validated schemas with server-validated ones.
* Handles objects, arrays, unions, intersections, and other composite types.
*
* @param schema The schema to transform
* @returns A new schema with server-validated replacements
*/
export function replaceWithServerValidatedSchemas<T extends z.ZodType>(schema: T): T {
// Check if this schema has a direct replacement
const replacement = schemaReplacementMap.get(schema);
if (replacement) {
return replacement as T;
}
// Access the internal definition
const def = schema._zod.def;
const type = def.type;
// Handle different schema types
if (type === 'object') {
// For objects, recursively transform the shape
const shape = (def as any).shape;
if (!shape) {
return schema;
}
const newShape: Record<string, z.ZodType> = {};
for (const key in shape) {
newShape[key] = replaceWithServerValidatedSchemas(shape[key]);
}
// Create a new object with the transformed shape
const newSchema = z.object(newShape);
// Preserve the original object configuration (strict/strip/passthrough)
const config = (def as any).config;
if (config?.type === 'strict') {
return newSchema.strict();
} else if (config?.type === 'loose') {
return newSchema.passthrough();
}
return newSchema;
}
if (type === 'array') {
// For arrays, transform the element type
const element = (def as any).element;
if (!element) {
return schema;
}
const newElement = replaceWithServerValidatedSchemas(element);
return z.array(newElement);
}
if (type === 'union') {
// For unions, transform all options
const options = (def as any).options;
if (!options || !Array.isArray(options)) {
return schema;
}
const newOptions = options.map((opt) => replaceWithServerValidatedSchemas(opt));
return z.union(newOptions as [z.ZodType, z.ZodType, ...z.ZodType[]]);
}
if (type === 'intersection') {
// For intersections, transform both sides
const left = (def as any).left;
const right = (def as any).right;
if (!left || !right) {
return schema;
}
const newLeft = replaceWithServerValidatedSchemas(left);
const newRight = replaceWithServerValidatedSchemas(right);
return z.intersection(newLeft, newRight);
}
if (type === 'optional') {
// For optional, transform the inner type
const inner = (def as any).inner;
if (!inner) {
return schema;
}
const newInner = replaceWithServerValidatedSchemas(inner);
return newInner.optional();
}
if (type === 'nullable') {
// For nullable, transform the inner type
const inner = (def as any).inner;
if (!inner) {
return schema;
}
const newInner = replaceWithServerValidatedSchemas(inner);
return newInner.nullable();
}
if (type === 'default') {
// For default, transform the inner type and preserve default value
const inner = (def as any).inner;
const defaultValue = (def as any).defaultValue;
if (!inner) {
return schema;
}
const newInner = replaceWithServerValidatedSchemas(inner);
return newInner.default(defaultValue);
}
if (type === 'catch') {
// For catch, transform the inner type and preserve catch value
const inner = (def as any).inner;
const catchValue = (def as any).catchValue;
if (!inner) {
return schema;
}
const newInner = replaceWithServerValidatedSchemas(inner);
return newInner.catch(catchValue);
}
if (type === 'readonly') {
// For readonly, transform the inner type
const inner = (def as any).inner;
if (!inner) {
return schema;
}
const newInner = replaceWithServerValidatedSchemas(inner);
return newInner.readonly();
}
if (type === 'promise') {
// For promise, transform the inner type
const inner = (def as any).inner;
if (!inner) {
return schema;
}
const newInner = replaceWithServerValidatedSchemas(inner);
return z.promise(newInner);
}
if (type === 'lazy') {
// For lazy schemas, we need to wrap the getter function
const getter = (def as any).getter;
if (!getter) {
return schema;
}
return z.lazy(() => replaceWithServerValidatedSchemas(getter()));
}
if (type === 'record') {
// For records, transform the value type
const valueType = (def as any).valueType;
const keyType = (def as any).keyType;
if (!valueType) {
return schema;
}
const newValueType = replaceWithServerValidatedSchemas(valueType);
if (keyType) {
return z.record(keyType, newValueType);
}
return z.record(newValueType);
}
if (type === 'map') {
// For maps, transform key and value types
const keyType = (def as any).keyType;
const valueType = (def as any).valueType;
if (!keyType || !valueType) {
return schema;
}
const newKeyType = replaceWithServerValidatedSchemas(keyType);
const newValueType = replaceWithServerValidatedSchemas(valueType);
return z.map(newKeyType, newValueType);
}
if (type === 'set') {
// For sets, transform the value type
const valueType = (def as any).valueType;
if (!valueType) {
return schema;
}
const newValueType = replaceWithServerValidatedSchemas(valueType);
return z.set(newValueType);
}
if (type === 'tuple') {
// For tuples, transform each item
const items = (def as any).items;
if (!items || !Array.isArray(items)) {
return schema;
}
const newItems = items.map((item) => replaceWithServerValidatedSchemas(item));
return z.tuple(newItems as [z.ZodType, ...z.ZodType[]]);
}
if (type === 'transform' || type === 'pipe') {
// For transforms and pipes, we need to handle carefully
// In v4, these might have different internal structure
// For now, return as-is since transforming these could break functionality
return schema;
}
// For primitive types and any unhandled types, return as-is
return schema;
}

View File

@@ -35,7 +35,7 @@ import { z } from 'zod';
const zId = z.string().min(1);
const zName = z.string().min(1).nullable();
const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => {
export const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => {
try {
await fetchModelConfigByIdentifier(modelIdentifier);
return true;
@@ -44,17 +44,16 @@ const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async
}
});
export const zImageWithDims = z
.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
})
.refine(async (v) => {
const { image_name } = v;
const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true });
return imageDTO !== null;
});
export const zImageWithDims = z.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
});
export const zServerValidatedImageWithDims = zImageWithDims.refine(async (v) => {
const { image_name } = v;
const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true });
return imageDTO !== null;
});
export type ImageWithDims = z.infer<typeof zImageWithDims>;
const zImageWithDimsDataURL = z.object({
@@ -249,10 +248,10 @@ const zCanvasObjectState = z.union([
]);
export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
const zIPAdapterConfig = z.object({
export const zIPAdapterConfig = z.object({
type: z.literal('ip_adapter'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2,
@@ -267,7 +266,7 @@ export type FLUXReduxImageInfluence = z.infer<typeof zFLUXReduxImageInfluence>;
const zFLUXReduxConfig = z.object({
type: z.literal('flux_redux'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
@@ -280,14 +279,14 @@ const zChatGPT4oReferenceImageConfig = z.object({
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
* there will be no way to switch between ref image types.
*/
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
});
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
const zFluxKontextReferenceImageConfig = z.object({
type: z.literal('flux_kontext_reference_image'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
});
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
@@ -359,7 +358,7 @@ export type CanvasInpaintMaskState = z.infer<typeof zCanvasInpaintMaskState>;
const zControlNetConfig = z.object({
type: z.literal('controlnet'),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
controlMode: zControlModeV2,
@@ -368,7 +367,7 @@ export type ControlNetConfig = z.infer<typeof zControlNetConfig>;
const zT2IAdapterConfig = z.object({
type: z.literal('t2i_adapter'),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
});
@@ -377,7 +376,7 @@ export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
const zControlLoRAConfig = z.object({
type: z.literal('control_lora'),
weight: z.number().gte(-1).lte(2),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
});
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
@@ -426,7 +425,7 @@ export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType
export const zLoRA = z.object({
id: z.string(),
isEnabled: z.boolean(),
model: zServerValidatedModelIdentifierField,
model: zModelIdentifierField,
weight: z.number().gte(-1).lte(2),
});
export type LoRA = z.infer<typeof zLoRA>;