mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): use regional guidance validation utils in graph builders
This commit is contained in:
@@ -3,15 +3,12 @@ import { deepClone } from 'common/util/deepClone';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type {
|
||||
CanvasRegionalGuidanceState,
|
||||
IPAdapterConfig,
|
||||
Rect,
|
||||
RegionalGuidanceReferenceImageState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import type { CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
|
||||
import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { BaseModelType, Invocation } from 'services/api/types';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
@@ -23,19 +20,12 @@ type AddedRegionResult = {
|
||||
addedIPAdapters: number;
|
||||
};
|
||||
|
||||
const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => {
|
||||
const isEnabled = rg.isEnabled;
|
||||
const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt);
|
||||
const hasIPAdapter = rg.referenceImages.filter(({ ipAdapter }) => isValidIPAdapter(ipAdapter, base)).length > 0;
|
||||
return isEnabled && (hasTextPrompt || hasIPAdapter);
|
||||
};
|
||||
|
||||
type AddRegionsArg = {
|
||||
manager: CanvasManager;
|
||||
regions: CanvasRegionalGuidanceState[];
|
||||
g: Graph;
|
||||
bbox: Rect;
|
||||
base: BaseModelType;
|
||||
model: ParameterModel;
|
||||
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
|
||||
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
|
||||
posCondCollect: Invocation<'collect'>;
|
||||
@@ -49,7 +39,7 @@ type AddRegionsArg = {
|
||||
* @param regions Array of regions to add
|
||||
* @param g The graph to add the layers to
|
||||
* @param bbox The bounding box
|
||||
* @param base The base model type
|
||||
* @param model The main model
|
||||
* @param posCond The positive conditioning node
|
||||
* @param negCond The negative conditioning node
|
||||
* @param posCondCollect The positive conditioning collector
|
||||
@@ -63,17 +53,23 @@ export const addRegions = async ({
|
||||
regions,
|
||||
g,
|
||||
bbox,
|
||||
base,
|
||||
model,
|
||||
posCond,
|
||||
negCond,
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
ipAdapterCollect,
|
||||
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
|
||||
const isSDXL = base === 'sdxl';
|
||||
const isFLUX = base === 'flux';
|
||||
const isSDXL = model.base === 'sdxl';
|
||||
const isFLUX = model.base === 'flux';
|
||||
|
||||
const validRegions = regions.filter((rg) => {
|
||||
if (!rg.isEnabled) {
|
||||
return false;
|
||||
}
|
||||
return getRegionalGuidanceWarnings(rg, model).length === 0;
|
||||
});
|
||||
|
||||
const validRegions = regions.filter((rg) => isValidRegion(rg, base));
|
||||
const results: AddedRegionResult[] = [];
|
||||
|
||||
for (const region of validRegions) {
|
||||
@@ -275,11 +271,7 @@ export const addRegions = async ({
|
||||
}
|
||||
}
|
||||
|
||||
const validRGIPAdapters: RegionalGuidanceReferenceImageState[] = region.referenceImages.filter(({ ipAdapter }) =>
|
||||
isValidIPAdapter(ipAdapter, base)
|
||||
);
|
||||
|
||||
for (const { id, ipAdapter } of validRGIPAdapters) {
|
||||
for (const { id, ipAdapter } of region.referenceImages) {
|
||||
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
|
||||
|
||||
result.addedIPAdapters++;
|
||||
@@ -313,11 +305,3 @@ export const addRegions = async ({
|
||||
|
||||
return results;
|
||||
};
|
||||
|
||||
const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {
|
||||
// Must be a model that matches the current base and must have a control image
|
||||
const hasModel = Boolean(ipAdapter.model);
|
||||
const modelMatchesBase = ipAdapter.model?.base === base;
|
||||
const hasImage = Boolean(ipAdapter.image);
|
||||
return hasModel && modelMatchesBase && hasImage;
|
||||
};
|
||||
|
||||
@@ -224,7 +224,7 @@ export const buildFLUXGraph = async (
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
bbox: canvas.bbox.rect,
|
||||
base: modelConfig.base,
|
||||
model: modelConfig,
|
||||
posCond,
|
||||
negCond: null,
|
||||
posCondCollect,
|
||||
|
||||
@@ -270,7 +270,7 @@ export const buildSD1Graph = async (
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
bbox: canvas.bbox.rect,
|
||||
base: modelConfig.base,
|
||||
model: modelConfig,
|
||||
posCond,
|
||||
negCond,
|
||||
posCondCollect,
|
||||
|
||||
@@ -275,7 +275,7 @@ export const buildSDXLGraph = async (
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
bbox: canvas.bbox.rect,
|
||||
base: modelConfig.base,
|
||||
model: modelConfig,
|
||||
posCond,
|
||||
negCond,
|
||||
posCondCollect,
|
||||
|
||||
Reference in New Issue
Block a user