import { getStore } from 'app/store/nanostores/store'; import type { CanvasRasterLayerState, LoRA } from 'features/controlLayers/store/types'; import type { ControlNetConfigMetadata, IPAdapterConfigMetadata, MetadataValidateFunc, T2IAdapterConfigMetadata, } from 'features/metadata/types'; import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers'; import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas'; import type { BaseModelType } from 'services/api/types'; import { assert } from 'tsafe'; /** * Checks the given base model type against the currently-selected model's base type and throws an error if they are * incompatible. * @param base The base model type to validate. * @param message An optional message to use in the error if the base model is incompatible. */ const validateBaseCompatibility = (base?: BaseModelType, message?: string) => { if (!base) { throw new InvalidModelConfigError(message || 'Missing base'); } const currentBase = getStore().getState().params.model?.base; if (currentBase && base !== currentBase) { throw new InvalidModelConfigError(message || `Incompatible base models: ${base} and ${currentBase}`); } }; const validateRefinerModel: MetadataValidateFunc = (refinerModel) => { validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model'); return new Promise((resolve) => resolve(refinerModel)); }; const validateVAEModel: MetadataValidateFunc = (vaeModel) => { validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model'); return new Promise((resolve) => resolve(vaeModel)); }; const validateLoRA: MetadataValidateFunc = (lora) => { validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model'); return new Promise((resolve) => resolve(lora)); }; const validateLoRAs: MetadataValidateFunc = (loras) => { const validatedLoRAs: LoRA[] = []; loras.forEach((lora) => { try { validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model'); validatedLoRAs.push(lora); } catch { // This is a no-op - we want to continue validating the rest of the LoRAs, and an empty list is valid. } }); return new Promise((resolve) => resolve(validatedLoRAs)); }; const validateControlNet: MetadataValidateFunc = (controlNet) => { validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model'); return new Promise((resolve) => resolve(controlNet)); }; const validateControlNets: MetadataValidateFunc = (controlNets) => { const validatedControlNets: ControlNetConfigMetadata[] = []; controlNets.forEach((controlNet) => { try { validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model'); validatedControlNets.push(controlNet); } catch { // This is a no-op - we want to continue validating the rest of the ControlNets, and an empty list is valid. } }); return new Promise((resolve) => resolve(validatedControlNets)); }; const validateT2IAdapter: MetadataValidateFunc = (t2iAdapter) => { validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model'); return new Promise((resolve) => resolve(t2iAdapter)); }; const validateT2IAdapters: MetadataValidateFunc = (t2iAdapters) => { const validatedT2IAdapters: T2IAdapterConfigMetadata[] = []; t2iAdapters.forEach((t2iAdapter) => { try { validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model'); validatedT2IAdapters.push(t2iAdapter); } catch { // This is a no-op - we want to continue validating the rest of the T2I Adapters, and an empty list is valid. } }); return new Promise((resolve) => resolve(validatedT2IAdapters)); }; const validateIPAdapter: MetadataValidateFunc = (ipAdapter) => { validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model'); return new Promise((resolve) => resolve(ipAdapter)); }; const validateIPAdapters: MetadataValidateFunc = (ipAdapters) => { const validatedIPAdapters: IPAdapterConfigMetadata[] = []; ipAdapters.forEach((ipAdapter) => { try { validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model'); validatedIPAdapters.push(ipAdapter); } catch { // This is a no-op - we want to continue validating the rest of the IP Adapters, and an empty list is valid. } }); return new Promise((resolve) => resolve(validatedIPAdapters)); }; const validateLayer: MetadataValidateFunc = async (layer) => { if (layer.type === 'control_adapter_layer') { const model = layer.controlAdapter.model; assert(model, 'Control Adapter layer missing model'); validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); } if (layer.type === 'ip_adapter_layer') { const model = layer.ipAdapter.model; assert(model, 'IP Adapter layer missing model'); validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); } if (layer.type === 'regional_guidance_layer') { for (const ipa of layer.ipAdapters) { const model = ipa.model; assert(model, 'IP Adapter layer missing model'); validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); } } return layer; }; const validateLayers: MetadataValidateFunc = async (layers) => { const validatedLayers: CanvasRasterLayerState[] = []; for (const l of layers) { try { const validated = await validateLayer(l); validatedLayers.push(validated); } catch { // This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid. } } return new Promise((resolve) => resolve(validatedLayers)); }; export const validators = { refinerModel: validateRefinerModel, vaeModel: validateVAEModel, lora: validateLoRA, loras: validateLoRAs, controlNet: validateControlNet, controlNets: validateControlNets, t2iAdapter: validateT2IAdapter, t2iAdapters: validateT2IAdapters, ipAdapter: validateIPAdapter, ipAdapters: validateIPAdapters, layer: validateLayer, layers: validateLayers, } as const;