mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-02 14:05:17 -05:00
Update frontend to support regional prompts with FLUX in the canvas.
This commit is contained in:
@@ -63,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => {
|
||||
justifyContent="flex-start"
|
||||
leftIcon={<PiPlusBold />}
|
||||
onClick={addRegionalGuidance}
|
||||
isDisabled={isFLUX || isSD3}
|
||||
isDisabled={isSD3}
|
||||
>
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
</Button>
|
||||
|
||||
@@ -49,7 +49,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
|
||||
{t('controlLayers.inpaintMask')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isSD3}>
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
|
||||
|
||||
@@ -50,14 +50,15 @@ export const addRegions = async (
|
||||
g: Graph,
|
||||
bbox: Rect,
|
||||
base: BaseModelType,
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||
denoise: Invocation<'denoise_latents' | 'flux_denoise'>,
|
||||
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>,
|
||||
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null,
|
||||
posCondCollect: Invocation<'collect'>,
|
||||
negCondCollect: Invocation<'collect'>,
|
||||
negCondCollect: Invocation<'collect'> | null,
|
||||
ipAdapterCollect: Invocation<'collect'>
|
||||
): Promise<AddedRegionResult[]> => {
|
||||
const isSDXL = base === 'sdxl';
|
||||
const isFLUX = base === 'flux';
|
||||
|
||||
const validRegions = regions.filter((rg) => isValidRegion(rg, base));
|
||||
const results: AddedRegionResult[] = [];
|
||||
@@ -94,20 +95,27 @@ export const addRegions = async (
|
||||
if (region.positivePrompt) {
|
||||
// The main positive conditioning node
|
||||
result.addedPositivePrompt = true;
|
||||
const regionalPosCond = g.addNode(
|
||||
isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_positive_cond'),
|
||||
prompt: region.positivePrompt,
|
||||
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_positive_cond'),
|
||||
prompt: region.positivePrompt,
|
||||
}
|
||||
);
|
||||
let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
|
||||
if (isSDXL) {
|
||||
regionalPosCond = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_positive_cond'),
|
||||
prompt: region.positivePrompt,
|
||||
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||
});
|
||||
} else if (isFLUX) {
|
||||
regionalPosCond = g.addNode({
|
||||
type: 'flux_text_encoder',
|
||||
id: getPrefixedId('prompt_region_positive_cond'),
|
||||
prompt: region.positivePrompt,
|
||||
});
|
||||
} else {
|
||||
regionalPosCond = g.addNode({
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_positive_cond'),
|
||||
prompt: region.positivePrompt,
|
||||
});
|
||||
}
|
||||
// Connect the mask to the conditioning
|
||||
g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask');
|
||||
// Connect the conditioning to the collector
|
||||
@@ -115,38 +123,55 @@ export const addRegions = async (
|
||||
// Copy the connections to the "global" positive conditioning node to the regional cond
|
||||
if (posCond.type === 'compel') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else if (posCond.type === 'sdxl_compel_prompt') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else if (posCond.type === 'flux_text_encoder') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
assert(false, 'Unsupported positive conditioning node type.');
|
||||
}
|
||||
}
|
||||
|
||||
if (region.negativePrompt) {
|
||||
result.addedNegativePrompt = true;
|
||||
assert(negCond, 'Negative conditioning node is required if there is a negative prompt');
|
||||
assert(negCondCollect, 'Negative conditioning collector is required if there is a negative prompt');
|
||||
|
||||
// The main negative conditioning node
|
||||
const regionalNegCond = g.addNode(
|
||||
isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_negative_cond'),
|
||||
prompt: region.negativePrompt,
|
||||
style: region.negativePrompt,
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_negative_cond'),
|
||||
prompt: region.negativePrompt,
|
||||
}
|
||||
);
|
||||
result.addedNegativePrompt = true;
|
||||
let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
|
||||
if (isSDXL) {
|
||||
regionalNegCond = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_negative_cond'),
|
||||
prompt: region.negativePrompt,
|
||||
style: region.negativePrompt,
|
||||
});
|
||||
} else if (isFLUX) {
|
||||
regionalNegCond = g.addNode({
|
||||
type: 'flux_text_encoder',
|
||||
id: getPrefixedId('prompt_region_negative_cond'),
|
||||
prompt: region.negativePrompt,
|
||||
});
|
||||
} else {
|
||||
regionalNegCond = g.addNode({
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_negative_cond'),
|
||||
prompt: region.negativePrompt,
|
||||
});
|
||||
}
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask');
|
||||
// Connect the conditioning to the collector
|
||||
@@ -158,17 +183,27 @@ export const addRegions = async (
|
||||
clone.destination.node_id = regionalNegCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
} else if (negCond.type === 'sdxl_compel_prompt') {
|
||||
for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalNegCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else if (negCond.type === 'flux_text_encoder') {
|
||||
for (const edge of g.getEdgesTo(negCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalNegCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
assert(false, 'Unsupported negative conditioning node type.');
|
||||
}
|
||||
}
|
||||
|
||||
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
|
||||
if (region.autoNegative && region.positivePrompt) {
|
||||
assert(negCondCollect, 'Negative conditioning collector is required if there is an auto-negative setting');
|
||||
|
||||
result.addedAutoNegativePositivePrompt = true;
|
||||
// We re-use the mask image, but invert it when converting to tensor
|
||||
const invertTensorMask = g.addNode({
|
||||
@@ -178,20 +213,27 @@ export const addRegions = async (
|
||||
// Connect the OG mask image to the inverted mask-to-tensor node
|
||||
g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
|
||||
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
|
||||
const regionalPosCondInverted = g.addNode(
|
||||
isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_positive_cond_inverted'),
|
||||
prompt: region.positivePrompt,
|
||||
style: region.positivePrompt,
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_positive_cond_inverted'),
|
||||
prompt: region.positivePrompt,
|
||||
}
|
||||
);
|
||||
let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
|
||||
if (isSDXL) {
|
||||
regionalPosCondInverted = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_positive_cond_inverted'),
|
||||
prompt: region.positivePrompt,
|
||||
style: region.positivePrompt,
|
||||
});
|
||||
} else if (isFLUX) {
|
||||
regionalPosCondInverted = g.addNode({
|
||||
type: 'flux_text_encoder',
|
||||
id: getPrefixedId('prompt_region_positive_cond_inverted'),
|
||||
prompt: region.positivePrompt,
|
||||
});
|
||||
} else {
|
||||
regionalPosCondInverted = g.addNode({
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_positive_cond_inverted'),
|
||||
prompt: region.positivePrompt,
|
||||
});
|
||||
}
|
||||
// Connect the inverted mask to the conditioning
|
||||
g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask');
|
||||
// Connect the conditioning to the negative collector
|
||||
@@ -203,12 +245,20 @@ export const addRegions = async (
|
||||
clone.destination.node_id = regionalPosCondInverted.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
} else if (posCond.type === 'sdxl_compel_prompt') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCondInverted.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else if (posCond.type === 'flux_text_encoder') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCondInverted.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
assert(false, 'Unsupported positive conditioning node type.');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,6 +267,8 @@ export const addRegions = async (
|
||||
);
|
||||
|
||||
for (const { id, ipAdapter } of validRGIPAdapters) {
|
||||
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
|
||||
|
||||
result.addedIPAdapters++;
|
||||
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
||||
assert(model, 'IP Adapter model is required');
|
||||
@@ -250,7 +302,7 @@ export const addRegions = async (
|
||||
};
|
||||
|
||||
const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {
|
||||
// Must be have a model that matches the current base and must have a control image
|
||||
// Must be a model that matches the current base and must have a control image
|
||||
const hasModel = Boolean(ipAdapter.model);
|
||||
const modelMatchesBase = ipAdapter.model?.base === base;
|
||||
const hasImage = Boolean(ipAdapter.image);
|
||||
|
||||
@@ -11,6 +11,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo
|
||||
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
|
||||
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
|
||||
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
||||
import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
@@ -79,7 +80,10 @@ export const buildFLUXGraph = async (
|
||||
id: getPrefixedId('flux_text_encoder'),
|
||||
prompt: positivePrompt,
|
||||
});
|
||||
|
||||
const posCondCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('pos_cond_collect'),
|
||||
});
|
||||
const denoise = g.addNode({
|
||||
type: 'flux_denoise',
|
||||
id: getPrefixedId('flux_denoise'),
|
||||
@@ -104,13 +108,12 @@ export const buildFLUXGraph = async (
|
||||
g.addEdge(modelLoader, 'clip', posCond, 'clip');
|
||||
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
|
||||
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
|
||||
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
|
||||
g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning');
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
addFLUXLoRAs(state, g, denoise, modelLoader, posCond);
|
||||
|
||||
g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning');
|
||||
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
assert(modelConfig.base === 'flux');
|
||||
|
||||
@@ -216,7 +219,22 @@ export const buildFLUXGraph = async (
|
||||
});
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
|
||||
|
||||
const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
|
||||
const regionsResult = await addRegions(
|
||||
manager,
|
||||
canvas.regionalGuidance.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
modelConfig.base,
|
||||
denoise,
|
||||
posCond,
|
||||
null,
|
||||
posCondCollect,
|
||||
null,
|
||||
ipAdapterCollector
|
||||
);
|
||||
|
||||
const totalIPAdaptersAdded =
|
||||
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
|
||||
if (totalIPAdaptersAdded > 0) {
|
||||
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
|
||||
} else {
|
||||
|
||||
@@ -6564,6 +6564,11 @@ export type components = {
|
||||
* @description The name of conditioning tensor
|
||||
*/
|
||||
conditioning_name: string;
|
||||
/**
|
||||
* @description The mask associated with this conditioning tensor. Excluded regions should be set to False, included regions should be set to True.
|
||||
* @default null
|
||||
*/
|
||||
mask?: components["schemas"]["TensorField"] | null;
|
||||
};
|
||||
/**
|
||||
* FluxConditioningOutput
|
||||
@@ -6771,15 +6776,17 @@ export type components = {
|
||||
*/
|
||||
transformer?: components["schemas"]["TransformerField"];
|
||||
/**
|
||||
* Positive Text Conditioning
|
||||
* @description Positive conditioning tensor
|
||||
* @default null
|
||||
*/
|
||||
positive_text_conditioning?: components["schemas"]["FluxConditioningField"];
|
||||
positive_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][];
|
||||
/**
|
||||
* Negative Text Conditioning
|
||||
* @description Negative conditioning tensor. Can be None if cfg_scale is 1.0.
|
||||
* @default null
|
||||
*/
|
||||
negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | null;
|
||||
negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][] | null;
|
||||
/**
|
||||
* CFG Scale
|
||||
* @description Classifier-Free Guidance scale
|
||||
@@ -7133,6 +7140,11 @@ export type components = {
|
||||
* @default null
|
||||
*/
|
||||
prompt?: string;
|
||||
/**
|
||||
* @description A mask defining the region that this conditioning prompt applies to.
|
||||
* @default null
|
||||
*/
|
||||
mask?: components["schemas"]["TensorField"] | null;
|
||||
/**
|
||||
* type
|
||||
* @default flux_text_encoder
|
||||
|
||||
Reference in New Issue
Block a user