feat(ui): use regional guidance validation utils in graph builders

This commit is contained in:
psychedelicious
2024-11-29 13:26:09 +10:00
parent 3905c97e32
commit df0c7d73f3
4 changed files with 20 additions and 36 deletions

View File

@@ -3,15 +3,12 @@ import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
CanvasRegionalGuidanceState,
IPAdapterConfig,
Rect,
RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import type { CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
import type { BaseModelType, Invocation } from 'services/api/types';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
@@ -23,19 +20,12 @@ type AddedRegionResult = {
addedIPAdapters: number;
};
const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => {
const isEnabled = rg.isEnabled;
const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt);
const hasIPAdapter = rg.referenceImages.filter(({ ipAdapter }) => isValidIPAdapter(ipAdapter, base)).length > 0;
return isEnabled && (hasTextPrompt || hasIPAdapter);
};
type AddRegionsArg = {
manager: CanvasManager;
regions: CanvasRegionalGuidanceState[];
g: Graph;
bbox: Rect;
base: BaseModelType;
model: ParameterModel;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
posCondCollect: Invocation<'collect'>;
@@ -49,7 +39,7 @@ type AddRegionsArg = {
* @param regions Array of regions to add
* @param g The graph to add the layers to
* @param bbox The bounding box
* @param base The base model type
* @param model The main model
* @param posCond The positive conditioning node
* @param negCond The negative conditioning node
* @param posCondCollect The positive conditioning collector
@@ -63,17 +53,23 @@ export const addRegions = async ({
regions,
g,
bbox,
base,
model,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollect,
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
const isSDXL = base === 'sdxl';
const isFLUX = base === 'flux';
const isSDXL = model.base === 'sdxl';
const isFLUX = model.base === 'flux';
const validRegions = regions.filter((rg) => {
if (!rg.isEnabled) {
return false;
}
return getRegionalGuidanceWarnings(rg, model).length === 0;
});
const validRegions = regions.filter((rg) => isValidRegion(rg, base));
const results: AddedRegionResult[] = [];
for (const region of validRegions) {
@@ -275,11 +271,7 @@ export const addRegions = async ({
}
}
const validRGIPAdapters: RegionalGuidanceReferenceImageState[] = region.referenceImages.filter(({ ipAdapter }) =>
isValidIPAdapter(ipAdapter, base)
);
for (const { id, ipAdapter } of validRGIPAdapters) {
for (const { id, ipAdapter } of region.referenceImages) {
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
result.addedIPAdapters++;
@@ -313,11 +305,3 @@ export const addRegions = async ({
return results;
};
const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {
// Must be a model that matches the current base and must have a control image
const hasModel = Boolean(ipAdapter.model);
const modelMatchesBase = ipAdapter.model?.base === base;
const hasImage = Boolean(ipAdapter.image);
return hasModel && modelMatchesBase && hasImage;
};

View File

@@ -224,7 +224,7 @@ export const buildFLUXGraph = async (
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
base: modelConfig.base,
model: modelConfig,
posCond,
negCond: null,
posCondCollect,

View File

@@ -270,7 +270,7 @@ export const buildSD1Graph = async (
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
base: modelConfig.base,
model: modelConfig,
posCond,
negCond,
posCondCollect,

View File

@@ -275,7 +275,7 @@ export const buildSDXLGraph = async (
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
base: modelConfig.base,
model: modelConfig,
posCond,
negCond,
posCondCollect,