Compare commits

...

34 Commits

Author SHA1 Message Date
Ryan Dick
8d04ec3f95 Improve docs related to dynamic T5 sequence length selection. 2024-11-29 16:11:51 +00:00
Ryan Dick
4581a37a48 Dynamically select smaller t5 seq len to save inference time. 2024-11-29 16:02:25 +00:00
Ryan Dick
54b7f9a063 FLUX Regional Prompting (#7388)
## Summary

This PR adds support for regional prompting with FLUX.

### Example 1
Global prompt: `An architecture rendering of the reception area of a
corporate office with modern decor.`
<img width="1386" alt="image"
src="https://github.com/user-attachments/assets/c8169bdb-49a9-44bc-bd9e-58d98e09094b">

![image](https://github.com/user-attachments/assets/4a426be9-9d7a-4527-b27c-2d2514ee73fe)

## QA Instructions

- [x] Test that there is no slowdown in the base case with a single
global prompt.
- [x] Test image fully covered by regional masks.
- [x] Test image covered by region masks with small gaps.
- [x] Test region masks with large unmasked ‘background’ regions.
- [x] Test region masks with significant overlap.
- [x] Test multiple global prompts.
- [x] Test no global prompt.
- [x] Test regional negative prompts (It runs... but results are not
great. Needs more tuning to be useful.)
- Test compatibility with:
    - [x] ControlNet
    - [x] LoRA
    - [x] IP-Adapter

## Remaining TODO

- [x] Disable the following UI features for FLUX prompt regions:
negative prompts, reference images, auto-negative.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2024-11-29 08:56:42 -05:00
psychedelicious
7d488a5352 feat(ui): add delete button to regional ref image empty state 2024-11-29 15:51:24 +10:00
psychedelicious
4d7667f63d fix(ui): add missing translations 2024-11-29 15:43:49 +10:00
psychedelicious
08704ee8ec feat(ui): use canvas layer validators in control/ip adapter graph builders 2024-11-29 15:32:48 +10:00
psychedelicious
5910892c33 Merge remote-tracking branch 'origin/main' into ryan/flux-regional-prompting 2024-11-29 15:19:39 +10:00
psychedelicious
46a09d9e90 feat(ui): format warnings tooltip 2024-11-29 13:32:51 +10:00
psychedelicious
df0c7d73f3 feat(ui): use regional guidance validation utils in graph builders 2024-11-29 13:26:09 +10:00
psychedelicious
3905c97e32 feat(ui): return translation keys from validation utils instead of translated strings 2024-11-29 13:25:09 +10:00
psychedelicious
0be796a808 feat(ui): use layer validation utils in invoke readiness utils 2024-11-29 13:14:26 +10:00
psychedelicious
7dd33b0f39 feat(ui): add indicator to canvas layer headers, displaying validation warnings
If there are any issues with the layer, the icon is displayed. If the layer is disabled, the icon is greyed out but still visible.
2024-11-29 13:13:47 +10:00
psychedelicious
484aaf1595 feat(ui): add canvas layer validation utils
These helpers consolidate layer validation checks. For example, checking that the layer has content drawn, is compatible with the selected main model, has valid reference images, etc.
2024-11-29 13:12:32 +10:00
psychedelicious
c276b60af9 tidy(ui): use object for addRegions graph builder util arg 2024-11-29 08:49:41 +10:00
Ryan Dick
5d8dd6e26e Fix FLUX regional negative prompts. 2024-11-28 18:49:29 +00:00
Emmanuel Ferdman
5bca68d873 docs: update code of conduct reference
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2024-11-27 17:38:33 -08:00
Ryan Dick
64364e7911 Short-circuit if there are no region masks in FLUX and don't apply attention masking. 2024-11-27 22:40:10 +00:00
Ryan Dick
6565cea039 Comment unused _prepare_unrestricted_attn_mask(...) for future reference. 2024-11-27 22:16:44 +00:00
Ryan Dick
3ebd8d6c07 Delete outdated TODO comment. 2024-11-27 22:13:25 +00:00
Ryan Dick
e970185161 Tweak flux regional prompting attention scheme based on latest experimentation. 2024-11-27 22:13:07 +00:00
Ryan Dick
fa5653cdf7 Remove unused 'denoise' param to addRegions(). 2024-11-27 17:08:42 +00:00
Ryan Dick
9a7b000995 Update frontend to support regional prompts with FLUX in the canvas. 2024-11-27 17:04:43 +00:00
Ryan Dick
3a27242838 Bump transformers. The main motivation for this bump is to ingest a fix for DepthAnything postprocessing artifacts. 2024-11-27 07:46:16 -08:00
Ryan Dick
b54463d294 Allow regional prompting background regions to attend to themselves and to the entire txt embedding. 2024-11-26 17:57:31 +00:00
Ryan Dick
faee79dc95 Distinguish between restricted and unrestricted attn masks in FLUX regional prompting. 2024-11-26 16:55:52 +00:00
Ryan Dick
e01f66b026 Apply regional attention masks in the single stream blocks in addition to the double stream blocks. 2024-11-25 22:40:08 +00:00
Ryan Dick
53abdde242 Update Flux RegionalPromptingExtension to prepare both a mask with restricted image self-attention and a mask with unrestricted image self attention. 2024-11-25 22:04:23 +00:00
Ryan Dick
94c088300f Be smarter about selecting the global CLIP embedding for FLUX regional prompting. 2024-11-25 20:15:04 +00:00
Ryan Dick
3741a6f5e0 Fix device handling for regional masks and apply the attention mask in the FLUX double stream block. 2024-11-25 16:02:03 +00:00
Ryan Dick
2c23b8414c Use a single global CLIP embedding for FLUX regional guidance. 2024-11-22 23:01:43 +00:00
Ryan Dick
20356c0746 Fixup the logic for preparing FLUX regional prompt attention masks. 2024-11-21 22:46:25 +00:00
Ryan Dick
bad1149504 WIP - add rough logic for preparing the FLUX regional prompting attention mask. 2024-11-20 22:29:36 +00:00
Ryan Dick
fda7aaa7ca Pass RegionalPromptingExtension down to the CustomDoubleStreamBlockProcessor in FLUX. 2024-11-20 19:48:04 +00:00
Ryan Dick
85c616fa34 WIP - Pass prompt masks to FLUX model during denoising. 2024-11-20 18:51:43 +00:00
28 changed files with 1158 additions and 381 deletions

View File

@@ -38,7 +38,7 @@ This project is a combined effort of dedicated people from across the world. [C
## Code of Conduct
The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/docs/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
By making a contribution to this project, you certify that:

View File

@@ -250,6 +250,11 @@ class FluxConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
mask: Optional[TensorField] = Field(
default=None,
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)
class SD3ConditioningField(BaseModel):

View File

@@ -30,6 +30,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
@@ -42,6 +43,7 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
@@ -87,10 +89,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
title="Transformer",
)
positive_text_conditioning: FluxConditioningField = InputField(
positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_text_conditioning: FluxConditioningField | None = InputField(
negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
default=None,
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
@@ -139,36 +141,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _load_text_conditioning(
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
return t5_embeddings, clip_embeddings
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16
# Load the conditioning data.
pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
context, self.positive_text_conditioning.conditioning_name, inference_dtype
)
neg_t5_embeddings: torch.Tensor | None = None
neg_clip_embeddings: torch.Tensor | None = None
if self.negative_text_conditioning is not None:
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype
)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
@@ -183,15 +161,45 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
seed=self.seed,
)
b, _c, latent_h, latent_w = noise.shape
packed_h = latent_h // 2
packed_w = latent_w // 2
# Load the conditioning data.
pos_text_conditionings = self._load_text_conditioning(
context=context,
cond_field=self.positive_text_conditioning,
packed_height=packed_h,
packed_width=packed_w,
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
neg_text_conditionings: list[FluxTextConditioning] | None = None
if self.negative_text_conditioning is not None:
neg_text_conditionings = self._load_text_conditioning(
context=context,
cond_field=self.negative_text_conditioning,
packed_height=packed_h,
packed_width=packed_w,
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
pos_text_conditionings, img_seq_len=packed_h * packed_w
)
neg_regional_prompting_extension = (
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
if neg_text_conditionings
else None
)
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=image_seq_len,
image_seq_len=packed_h * packed_w,
shift=not is_schnell,
)
@@ -228,28 +236,17 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inpaint_mask = self._prep_inpaint_mask(context, x)
b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
pos_txt_ids = torch.zeros(
pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
neg_txt_ids: torch.Tensor | None = None
if neg_t5_embeddings is not None:
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
neg_txt_ids = torch.zeros(
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
noise = pack(noise)
x = pack(x)
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
assert image_seq_len == x.shape[1]
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
# packed_w correctly.
assert packed_h * packed_w == x.shape[1]
# Prepare inpaint extension.
inpaint_extension: InpaintExtension | None = None
@@ -338,12 +335,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
model=transformer,
img=x,
img_ids=img_ids,
txt=pos_t5_embeddings,
txt_ids=pos_txt_ids,
vec=pos_clip_embeddings,
neg_txt=neg_t5_embeddings,
neg_txt_ids=neg_txt_ids,
neg_vec=neg_clip_embeddings,
pos_regional_prompting_extension=pos_regional_prompting_extension,
neg_regional_prompting_extension=neg_regional_prompting_extension,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
@@ -357,6 +350,43 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
x = unpack(x.float(), self.height, self.width)
return x
def _load_text_conditioning(
self,
context: InvocationContext,
cond_field: FluxConditioningField | list[FluxConditioningField],
packed_height: int,
packed_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[FluxTextConditioning]:
"""Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
# Normalize to a list of FluxConditioningFields.
cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
text_conditionings: list[FluxTextConditioning] = []
for cond_field in cond_list:
# Load the text embeddings.
cond_data = context.conditioning.load(cond_field.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
# Load the mask, if provided.
mask: Optional[torch.Tensor] = None
if cond_field.mask is not None:
mask = context.tensors.load(cond_field.mask.tensor_name)
mask = mask.to(device=device)
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
mask, packed_height, packed_width, dtype, device
)
text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
return text_conditionings
@classmethod
def prep_cfg_scale(
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
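For reference, the packed dimensions used in `_run_diffusion` above follow from FLUX's latent packing: the latents are grouped into 2x2 patches, so `packed_h * packed_w` equals the previous `image_seq_len` formula. A minimal sanity check (assuming the standard 8x VAE downscale and a 1024x1024 image):

```python
# Sanity check of the packed-dimension arithmetic used above (assumes the standard FLUX 8x VAE downscale).
image_h, image_w = 1024, 1024
latent_h, latent_w = image_h // 8, image_w // 8    # 128 x 128 latent grid
packed_h, packed_w = latent_h // 2, latent_w // 2  # 64 x 64 after 2x2 packing
img_seq_len = packed_h * packed_w                  # 4096 tokens in the packed sequence
assert img_seq_len == (latent_h * latent_w) // 4   # matches the previous image_seq_len formula
```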

View File

@@ -1,11 +1,18 @@
from contextlib import ExitStack
from typing import Iterator, Literal, Tuple
from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
TensorField,
UIComponent,
)
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -22,7 +29,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
title="FLUX Text Encoding",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.1.0",
version="1.2.0",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
@@ -41,9 +48,14 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
prompt: str = InputField(
description="Text prompt to encode.",
ui_component=UIComponent.Textarea,
use_short_t5_seq_len: bool = InputField(
description="Use a shorter sequence length for the T5 encoder if a short prompt is used. This can improve "
+ "performance and reduce peak memory, but may result in slightly different image outputs.",
default=True,
)
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
@torch.no_grad()
@@ -57,7 +69,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
)
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
return FluxConditioningOutput(
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
@@ -65,6 +79,12 @@ class FluxTextEncoderInvocation(BaseInvocation):
prompt = [self.prompt]
valid_seq_lens = [self.t5_max_seq_len]
if self.use_short_t5_seq_len:
# We allow a minimum sequence length of 128. Going too short results in more significant image changes.
valid_seq_lens = list(range(128, self.t5_max_seq_len, 128))
valid_seq_lens.append(self.t5_max_seq_len)
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
@@ -72,10 +92,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, T5Tokenizer)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False)
context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)
prompt_embeds = t5_encoder(prompt, valid_seq_lens)
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
@@ -113,10 +133,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
# There are currently no supported CLIP quantized models. Add support here if needed.
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True)
context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)
pooled_prompt_embeds = clip_encoder(prompt, [77])
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return pooled_prompt_embeds

View File

@@ -1,9 +1,10 @@
import einops
import torch
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.math import attention
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock
class CustomDoubleStreamBlockProcessor:
@@ -13,7 +14,12 @@ class CustomDoubleStreamBlockProcessor:
@staticmethod
def _double_stream_block_forward(
block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
block: DoubleStreamBlock,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
attn_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
values.
@@ -40,7 +46,7 @@ class CustomDoubleStreamBlockProcessor:
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
@@ -63,11 +69,15 @@ class CustomDoubleStreamBlockProcessor:
vec: torch.Tensor,
pe: torch.Tensor,
ip_adapter_extensions: list[XLabsIPAdapterExtension],
regional_prompting_extension: RegionalPromptingExtension,
) -> tuple[torch.Tensor, torch.Tensor]:
"""A custom implementation of DoubleStreamBlock.forward() with additional features:
- IP-Adapter support
"""
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
block, img, txt, vec, pe, attn_mask=attn_mask
)
# Apply IP-Adapter conditioning.
for ip_adapter_extension in ip_adapter_extensions:
@@ -81,3 +91,48 @@ class CustomDoubleStreamBlockProcessor:
)
return img, txt
class CustomSingleStreamBlockProcessor:
"""A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking,
etc.)
"""
@staticmethod
def _single_stream_block_forward(
block: SingleStreamBlock,
x: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""This function is a direct copy of SingleStreamBlock.forward()."""
mod, _ = block.modulation(vec)
x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift
qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1)
q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
q, k = block.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2))
return x + mod.gate * output
@staticmethod
def custom_single_block_forward(
timestep_index: int,
total_num_timesteps: int,
block_index: int,
block: SingleStreamBlock,
img: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
regional_prompting_extension: RegionalPromptingExtension,
) -> torch.Tensor:
"""A custom implementation of SingleStreamBlock.forward() with additional features:
- Masking
"""
attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)

View File

@@ -7,6 +7,7 @@ from tqdm import tqdm
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.model import Flux
@@ -18,14 +19,8 @@ def denoise(
# model input
img: torch.Tensor,
img_ids: torch.Tensor,
# positive text conditioning
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
# negative text conditioning
neg_txt: torch.Tensor | None,
neg_txt_ids: torch.Tensor | None,
neg_vec: torch.Tensor | None,
pos_regional_prompting_extension: RegionalPromptingExtension,
neg_regional_prompting_extension: RegionalPromptingExtension | None,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
@@ -61,9 +56,9 @@ def denoise(
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
)
@@ -78,9 +73,9 @@ def denoise(
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=step_index,
@@ -88,6 +83,7 @@ def denoise(
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
ip_adapter_extensions=pos_ip_adapter_extensions,
regional_prompting_extension=pos_regional_prompting_extension,
)
step_cfg_scale = cfg_scale[step_index]
@@ -97,15 +93,15 @@ def denoise(
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
# on systems with sufficient VRAM.
if neg_txt is None or neg_txt_ids is None or neg_vec is None:
if neg_regional_prompting_extension is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
neg_pred = model(
img=img,
img_ids=img_ids,
txt=neg_txt,
txt_ids=neg_txt_ids,
y=neg_vec,
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=step_index,
@@ -113,6 +109,7 @@ def denoise(
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
ip_adapter_extensions=neg_ip_adapter_extensions,
regional_prompting_extension=neg_regional_prompting_extension,
)
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
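The final line of each CFG step is the standard classifier-free guidance blend; a tiny standalone check of its behavior (values are arbitrary):

```python
import torch

pred = torch.tensor([2.0])      # positive (conditional) prediction
neg_pred = torch.tensor([1.0])  # negative (unconditional) prediction
for step_cfg_scale in (1.0, 2.0, 4.0):
    blended = neg_pred + step_cfg_scale * (pred - neg_pred)
    print(step_cfg_scale, blended.item())  # 1.0 -> 2.0 (no effect), 2.0 -> 3.0, 4.0 -> 5.0
```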

View File

@@ -0,0 +1,276 @@
from typing import Optional
import torch
import torchvision
from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.mask import to_standard_float_mask
class RegionalPromptingExtension:
"""A class for managing regional prompting with FLUX.
This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences).
"""
def __init__(
self,
regional_text_conditioning: FluxRegionalTextConditioning,
restricted_attn_mask: torch.Tensor | None = None,
):
self.regional_text_conditioning = regional_text_conditioning
self.restricted_attn_mask = restricted_attn_mask
def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
order = [self.restricted_attn_mask, None]
return order[block_index % len(order)]
def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
order = [self.restricted_attn_mask, None]
return order[block_index % len(order)]
@classmethod
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
"""Create a RegionalPromptingExtension from a list of text conditionings.
Args:
text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
"""
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
regional_text_conditioning, img_seq_len
)
return cls(
regional_text_conditioning=regional_text_conditioning,
restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
)
# Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy:
#
# @classmethod
# def _prepare_unrestricted_attn_mask(
# cls,
# regional_text_conditioning: FluxRegionalTextConditioning,
# img_seq_len: int,
# ) -> torch.Tensor:
# """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
# - img self-attention is not masked.
# - img regions attend to both txt within their own region and to global prompts.
# """
# device = TorchDevice.choose_torch_device()
# # Infer txt_seq_len from the t5_embeddings tensor.
# txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
# # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
# # Concatenation happens in the following order: [txt_seq, img_seq].
# # There are 4 portions of the attention mask to consider as we prepare it:
# # 1. txt attends to itself
# # 2. txt attends to corresponding regional img
# # 3. regional img attends to corresponding txt
# # 4. regional img attends to itself
# # Initialize empty attention mask.
# regional_attention_mask = torch.zeros(
# (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
# )
# for image_mask, t5_embedding_range in zip(
# regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
# ):
# # 1. txt attends to itself
# regional_attention_mask[
# t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
# ] = 1.0
# # 2. txt attends to corresponding regional img
# # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
# fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
# regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value
# # 3. regional img attends to corresponding txt
# # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
# fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
# regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
# # 4. regional img attends to itself
# # Allow unrestricted img self attention.
# regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
# # Convert attention mask to boolean.
# regional_attention_mask = regional_attention_mask > 0.5
# return regional_attention_mask
@classmethod
def _prepare_restricted_attn_mask(
cls,
regional_text_conditioning: FluxRegionalTextConditioning,
img_seq_len: int,
) -> torch.Tensor | None:
"""Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
- img self-attention is only allowed within regions.
- img regions only attend to txt within their own region, not to global prompts.
"""
# Identify background region. I.e. the region that is not covered by any region masks.
background_region_mask: None | torch.Tensor = None
for image_mask in regional_text_conditioning.image_masks:
if image_mask is not None:
if background_region_mask is None:
background_region_mask = torch.ones_like(image_mask)
background_region_mask *= 1 - image_mask
if background_region_mask is None:
# There are no region masks, short-circuit and return None.
# TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this is a rare
# use case and would make the logic here significantly more complicated.
return None
device = TorchDevice.choose_torch_device()
# Infer txt_seq_len from the t5_embeddings tensor.
txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
# In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
# Concatenation happens in the following order: [txt_seq, img_seq].
# There are 4 portions of the attention mask to consider as we prepare it:
# 1. txt attends to itself
# 2. txt attends to corresponding regional img
# 3. regional img attends to corresponding txt
# 4. regional img attends to itself
# Initialize empty attention mask.
regional_attention_mask = torch.zeros(
(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
)
for image_mask, t5_embedding_range in zip(
regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
):
# 1. txt attends to itself
regional_attention_mask[
t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
] = 1.0
if image_mask is not None:
# 2. txt attends to corresponding regional img
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
image_mask.view(1, img_seq_len)
)
# 3. regional img attends to corresponding txt
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
image_mask.view(img_seq_len, 1)
)
# 4. regional img attends to itself
image_mask = image_mask.view(img_seq_len, 1)
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
else:
# We don't allow attention between non-background image regions and global prompts. This helps to ensure
# that regions focus on their local prompts. We do, however, allow attention between background regions
# and global prompts. If we didn't do this, then the background regions would not attend to any txt
# embeddings, which we found experimentally to cause artifacts.
# 2. global txt attends to background region
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
background_region_mask.view(1, img_seq_len)
)
# 3. background region attends to global txt
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
background_region_mask.view(img_seq_len, 1)
)
# Allow background regions to attend to themselves.
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
# Convert attention mask to boolean.
regional_attention_mask = regional_attention_mask > 0.5
return regional_attention_mask
@classmethod
def _concat_regional_text_conditioning(
cls,
text_conditionings: list[FluxTextConditioning],
) -> FluxRegionalTextConditioning:
"""Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
concat_t5_embeddings: list[torch.Tensor] = []
concat_t5_embedding_ranges: list[Range] = []
image_masks: list[torch.Tensor | None] = []
# Choose global CLIP embedding.
# Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use
# the first prompt's CLIP embedding.
global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings
for text_conditioning in text_conditionings:
if text_conditioning.mask is None:
global_clip_embedding = text_conditioning.clip_embeddings
break
cur_t5_embedding_len = 0
for text_conditioning in text_conditionings:
concat_t5_embeddings.append(text_conditioning.t5_embeddings)
concat_t5_embedding_ranges.append(
Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
)
image_masks.append(text_conditioning.mask)
cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
# Initialize the txt_ids tensor.
pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
t5_txt_ids = torch.zeros(
pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
)
return FluxRegionalTextConditioning(
t5_embeddings=t5_embeddings,
clip_embeddings=global_clip_embedding,
t5_txt_ids=t5_txt_ids,
image_masks=image_masks,
t5_embedding_ranges=concat_t5_embedding_ranges,
)
@staticmethod
def preprocess_regional_prompt_mask(
mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
packed_height and packed_width are the target height and width of the mask in the 'packed' latent space.
Returns:
torch.Tensor: The processed mask. shape: (1, 1, packed_height * packed_width).
"""
if mask is None:
return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device)
mask = to_standard_float_mask(mask, out_dtype=dtype)
tf = torchvision.transforms.Resize(
(packed_height, packed_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
# Flatten the height and width dimensions into a single image_seq_len dimension.
return resized_mask.flatten(start_dim=2)
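To make the shapes concrete, here is a toy illustration of the mask layout and the per-block selection implemented above (sizes are hypothetical; the real `txt_seq_len` comes from the concatenated T5 embeddings):

```python
import torch

# Hypothetical sizes: two prompts with T5 sequence lengths 256 and 128, and a 64x64 packed latent grid.
txt_seq_len = 256 + 128
img_seq_len = 64 * 64
# The restricted mask is a boolean (txt_seq_len + img_seq_len)^2 matrix in [txt | img] token order;
# _prepare_restricted_attn_mask() fills in the four txt/img sub-blocks described in its comments.
restricted_attn_mask = torch.zeros(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len, dtype=torch.bool)

# get_double_stream_attn_mask() / get_single_stream_attn_mask() alternate between the restricted mask
# and unrestricted attention by block index.
order = [restricted_attn_mask, None]
assert order[0 % len(order)] is restricted_attn_mask  # even block indices: restricted attention
assert order[1 % len(order)] is None                  # odd block indices: unrestricted attention
```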

View File

@@ -5,10 +5,10 @@ from einops import rearrange
from torch import Tensor
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = rearrange(x, "B H L D -> B L (H D)")
return x
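With a boolean `attn_mask`, `scaled_dot_product_attention` only lets a query position attend to key positions where the mask is `True`, which is how the regional masks above restrict attention. A small self-contained check (not from the diff):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 4, 8)                 # (batch, heads, seq_len, head_dim)
mask = torch.eye(4, dtype=torch.bool)[None, None]   # each position may only attend to itself
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert torch.allclose(out, v, atol=1e-5)            # with a single allowed key, attention returns that value
```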

View File

@@ -5,7 +5,11 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
from invokeai.backend.flux.custom_block_processor import (
CustomDoubleStreamBlockProcessor,
CustomSingleStreamBlockProcessor,
)
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
@@ -95,6 +99,7 @@ class Flux(nn.Module):
controlnet_double_block_residuals: list[Tensor] | None,
controlnet_single_block_residuals: list[Tensor] | None,
ip_adapter_extensions: list[XLabsIPAdapterExtension],
regional_prompting_extension: RegionalPromptingExtension,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -117,7 +122,6 @@ class Flux(nn.Module):
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
for block_index, block in enumerate(self.double_blocks):
assert isinstance(block, DoubleStreamBlock)
img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
timestep_index=timestep_index,
total_num_timesteps=total_num_timesteps,
@@ -128,6 +132,7 @@ class Flux(nn.Module):
vec=vec,
pe=pe,
ip_adapter_extensions=ip_adapter_extensions,
regional_prompting_extension=regional_prompting_extension,
)
if controlnet_double_block_residuals is not None:
@@ -140,7 +145,17 @@ class Flux(nn.Module):
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
for block_index, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)
assert isinstance(block, SingleStreamBlock)
img = CustomSingleStreamBlockProcessor.custom_single_block_forward(
timestep_index=timestep_index,
total_num_timesteps=total_num_timesteps,
block_index=block_index,
block=block,
img=img,
vec=vec,
pe=pe,
regional_prompting_extension=regional_prompting_extension,
)
if controlnet_single_block_residuals is not None:
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]

View File

@@ -1,32 +1,53 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer
class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool):
super().__init__()
self.max_length = max_length
self.is_clip = is_clip
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
self.tokenizer = tokenizer
self.hf_module = encoder
self.hf_module = self.hf_module.eval().requires_grad_(False)
def forward(self, text: list[str]) -> Tensor:
def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
"""Encode text into a tensor.
Args:
text: A list of text prompts to encode.
valid_seq_lens: A list of valid sequence lengths. The shortest valid sequence length that can contain the
text will be used. If the largest valid sequence length cannot contain the text, the encoding will be
truncated.
"""
valid_seq_lens = sorted(valid_seq_lens)
# Perform initial encoding with the maximum valid sequence length.
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=False,
max_length=max(valid_seq_lens),
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
# Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
seq_len: int = batch_encoding["length"][0].item()
selected_seq_len = valid_seq_lens[-1]
for len in valid_seq_lens:
if len >= seq_len:
selected_seq_len = len
break
input_ids = batch_encoding["input_ids"][..., :selected_seq_len]
outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
input_ids=input_ids.to(self.hf_module.device),
attention_mask=None,
output_hidden_states=False,
)
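The selection above rounds the tokenized prompt length up to the shortest allowed bucket. For a FLUX dev model (`t5_max_seq_len=512`), the text encoder invocation builds buckets of [128, 256, 384, 512], so a short prompt encodes far fewer padding tokens. A standalone sketch of the same selection logic:

```python
# Standalone sketch of the sequence-length bucket selection used in HFEncoder.forward().
def select_seq_len(token_count: int, valid_seq_lens: list[int]) -> int:
    for candidate in sorted(valid_seq_lens):
        if candidate >= token_count:
            return candidate
    return max(valid_seq_lens)  # prompt longer than the largest bucket: encoding is truncated

t5_max_seq_len = 512  # FLUX dev; schnell uses 256
valid_seq_lens = list(range(128, t5_max_seq_len, 128)) + [t5_max_seq_len]  # [128, 256, 384, 512]
assert select_seq_len(40, valid_seq_lens) == 128
assert select_seq_len(300, valid_seq_lens) == 384
assert select_seq_len(600, valid_seq_lens) == 512
```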

View File

@@ -0,0 +1,36 @@
from dataclasses import dataclass
import torch
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
@dataclass
class FluxTextConditioning:
t5_embeddings: torch.Tensor
clip_embeddings: torch.Tensor
# If mask is None, the prompt is a global prompt.
mask: torch.Tensor | None
@dataclass
class FluxRegionalTextConditioning:
# Concatenated text embeddings.
# Shape: (1, concatenated_txt_seq_len, 4096)
t5_embeddings: torch.Tensor
# Shape: (1, concatenated_txt_seq_len, 3)
t5_txt_ids: torch.Tensor
# Global CLIP embeddings.
# Shape: (1, 768)
clip_embeddings: torch.Tensor
# A binary mask indicating the regions of the image that the prompt should be applied to. If None, the prompt is a
# global prompt.
# image_masks[i] is the mask for the ith prompt.
# image_masks[i] has shape (1, image_seq_len) and dtype torch.bool.
image_masks: list[torch.Tensor | None]
# List of ranges that represent the embedding ranges for each mask.
# t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
t5_embedding_ranges: list[Range]
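For example, a global prompt and a regional prompt might be represented as follows (illustrative sizes; dtypes and devices omitted):

```python
import torch

from invokeai.backend.flux.text_conditioning import FluxTextConditioning

# Illustrative shapes: one global prompt (mask=None) and one regional prompt with a preprocessed mask
# over a 64x64 packed latent grid.
global_prompt = FluxTextConditioning(
    t5_embeddings=torch.zeros(1, 512, 4096),  # (1, t5_seq_len, 4096)
    clip_embeddings=torch.zeros(1, 768),      # pooled CLIP embedding
    mask=None,                                # None marks this prompt as global
)
regional_prompt = FluxTextConditioning(
    t5_embeddings=torch.zeros(1, 256, 4096),
    clip_embeddings=torch.zeros(1, 768),
    mask=torch.ones(1, 1, 64 * 64),           # output of preprocess_regional_prompt_mask()
)
```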

View File

@@ -176,7 +176,8 @@
"reset": "Reset",
"none": "None",
"new": "New",
"generating": "Generating"
"generating": "Generating",
"warnings": "Warnings"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -1040,6 +1041,7 @@
"noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected",
"layer": {
"unsupportedModel": "layer not supported for selected base model",
"controlAdapterNoModelSelected": "no Control Adapter model selected",
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, bbox width is {{width}}",
@@ -1050,7 +1052,10 @@
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
"ipAdapterNoImageSelected": "no IP Adapter image selected",
"rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters",
"rgNoRegion": "no region selected"
"rgNegativePromptNotSupported": "negative prompt not supported for selected base model",
"rgReferenceImagesNotSupported": "regional reference images not supported for selected base model",
"rgAutoNegativeNotSupported": "auto-negative not supported for selected base model",
"emptyLayer": "empty layer"
}
},
"maskBlur": "Mask Blur",

View File

@@ -63,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX || isSD3}
isDisabled={isSD3}
>
{t('controlLayers.regionalGuidance')}
</Button>

View File

@@ -49,7 +49,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isSD3}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>

View File

@@ -1,27 +1,28 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import type { IconButtonProps } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';
import { PiXBold } from 'react-icons/pi';
type Props = {
type Props = Omit<IconButtonProps, 'aria-label'> & {
onDelete: () => void;
};
export const RegionalGuidanceDeletePromptButton = memo(({ onDelete }: Props) => {
export const RegionalGuidanceDeletePromptButton = memo(({ onDelete, ...rest }: Props) => {
const { t } = useTranslation();
return (
<Tooltip label={t('controlLayers.deletePrompt')}>
<IconButton
variant="link"
aria-label={t('controlLayers.deletePrompt')}
icon={<PiTrashSimpleFill />}
onClick={onDelete}
flexGrow={0}
size="sm"
p={0}
colorScheme="error"
/>
</Tooltip>
<IconButton
tooltip={t('common.delete')}
variant="link"
aria-label={t('common.delete')}
icon={<PiXBold />}
onClick={onDelete}
flexGrow={0}
size="sm"
p={0}
colorScheme="error"
{...rest}
/>
);
});

View File

@@ -1,8 +1,10 @@
import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { RegionalGuidanceDeletePromptButton } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { rgIPAdapterDeleted } from 'features/controlLayers/store/canvasSlice';
import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
@@ -31,6 +33,9 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const onDeleteIPAdapter = useCallback(() => {
dispatch(rgIPAdapterDeleted({ entityIdentifier, referenceImageId }));
}, [dispatch, entityIdentifier, referenceImageId]);
const dndTargetData = useMemo<SetRegionalGuidanceReferenceImageDndTargetData>(
() =>
@@ -43,6 +48,7 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
return (
<Flex flexDir="column" gap={3} position="relative" w="full" p={4}>
<RegionalGuidanceDeletePromptButton onDelete={onDeleteIPAdapter} position="absolute" top={0} insetInlineEnd={0} />
<Text textAlign="center" color="base.300">
<Trans
i18nKey="controlLayers.referenceImageEmptyState"

View File

@@ -1,6 +1,7 @@
import { Flex } from '@invoke-ai/ui-library';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
import { CanvasEntityHeaderWarnings } from 'features/controlLayers/components/common/CanvasEntityHeaderWarnings';
import { CanvasEntityIsBookmarkedForQuickSwitchToggle } from 'features/controlLayers/components/common/CanvasEntityIsBookmarkedForQuickSwitchToggle';
import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -11,6 +12,7 @@ export const CanvasEntityHeaderCommonActions = memo(() => {
return (
<Flex alignSelf="stretch">
<CanvasEntityHeaderWarnings />
<CanvasEntityIsBookmarkedForQuickSwitchToggle />
{entityIdentifier.type !== 'reference_image' && <CanvasEntityIsLockedToggle />}
<CanvasEntityEnabledToggle />

View File

@@ -0,0 +1,101 @@
import { Flex, IconButton, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsEnabled } from 'features/controlLayers/hooks/useEntityIsEnabled';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import {
getControlLayerWarnings,
getGlobalReferenceImageWarnings,
getInpaintMaskWarnings,
getRasterLayerWarnings,
getRegionalGuidanceWarnings,
} from 'features/controlLayers/store/validators';
import type { TFunction } from 'i18next';
import { upperFirst } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiWarningBold } from 'react-icons/pi';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunction) => {
return createSelector(selectCanvasSlice, selectModel, (canvas, model) => {
// This component is used within a <CanvasEntityStateGate /> so we can safely assume that the entity exists.
// Should never throw.
const entity = selectEntityOrThrow(canvas, entityIdentifier, 'CanvasEntityHeaderWarnings');
let warnings: string[] = [];
const entityType = entity.type;
if (entityType === 'control_layer') {
warnings = getControlLayerWarnings(entity, model);
} else if (entityType === 'regional_guidance') {
warnings = getRegionalGuidanceWarnings(entity, model);
} else if (entityType === 'inpaint_mask') {
warnings = getInpaintMaskWarnings(entity, model);
} else if (entityType === 'raster_layer') {
warnings = getRasterLayerWarnings(entity, model);
} else if (entityType === 'reference_image') {
warnings = getGlobalReferenceImageWarnings(entity, model);
} else {
assert<Equals<typeof entityType, never>>(false, 'Unexpected entity type');
}
// Return a stable reference if there are no warnings
if (warnings.length === 0) {
return EMPTY_ARRAY;
}
return warnings.map((w) => t(w)).map(upperFirst);
});
};
export const CanvasEntityHeaderWarnings = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const { t } = useTranslation();
const isEnabled = useEntityIsEnabled(entityIdentifier);
const selectWarnings = useMemo(() => buildSelectWarnings(entityIdentifier, t), [entityIdentifier, t]);
const warnings = useAppSelector(selectWarnings);
if (warnings.length === 0) {
return null;
}
return (
// Using IconButton here because it matches the styling of the actual buttons in the header without any finagling,
// but it's not a button
<IconButton
as="span"
size="sm"
variant="link"
alignSelf="stretch"
aria-label="warnings"
tooltip={<TooltipContent warnings={warnings} />}
icon={<PiWarningBold />}
colorScheme="warning"
isDisabled={!isEnabled}
/>
);
});
CanvasEntityHeaderWarnings.displayName = 'CanvasEntityHeaderWarnings';
const TooltipContent = memo((props: { warnings: string[] }) => {
const { t } = useTranslation();
return (
<Flex flexDir="column">
<Text>{t('common.warnings')}:</Text>
<UnorderedList>
{props.warnings.map((warning, index) => (
<ListItem key={index}>{warning}</ListItem>
))}
</UnorderedList>
</Flex>
);
});
TooltipContent.displayName = 'TooltipContent';

View File

@@ -0,0 +1,133 @@
import type {
CanvasControlLayerState,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
export const getRegionalGuidanceWarnings = (
entity: CanvasRegionalGuidanceState,
model: ParameterModel | null
): string[] => {
const warnings: string[] = [];
if (entity.objects.length === 0) {
// Layer is in empty state - skip other checks
warnings.push('parameters.invoke.layer.emptyLayer');
} else {
if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) {
// Must have at least 1 prompt or IP Adapter
warnings.push('parameters.invoke.layer.rgNoPromptsOrIPAdapters');
}
if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push('parameters.invoke.layer.unsupportedModel');
} else if (model.base === 'flux') {
// Some features are not supported for flux models
if (entity.negativePrompt !== null) {
warnings.push('parameters.invoke.layer.rgNegativePromptNotSupported');
}
if (entity.referenceImages.length > 0) {
warnings.push('parameters.invoke.layer.rgReferenceImagesNotSupported');
}
if (entity.autoNegative) {
warnings.push('parameters.invoke.layer.rgAutoNegativeNotSupported');
}
} else {
entity.referenceImages.forEach(({ ipAdapter }) => {
if (!ipAdapter.model) {
// No model selected
warnings.push('parameters.invoke.layer.ipAdapterNoModelSelected');
} else if (ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push('parameters.invoke.layer.ipAdapterIncompatibleBaseModel');
}
if (!ipAdapter.image) {
// No image selected
warnings.push('parameters.invoke.layer.ipAdapterNoImageSelected');
}
});
}
}
}
return warnings;
};
export const getGlobalReferenceImageWarnings = (
entity: CanvasReferenceImageState,
model: ParameterModel | null
): string[] => {
const warnings: string[] = [];
if (!entity.ipAdapter.model) {
// No model selected
warnings.push('parameters.invoke.layer.ipAdapterNoModelSelected');
} else if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push('parameters.invoke.layer.unsupportedModel');
} else if (entity.ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push('parameters.invoke.layer.ipAdapterIncompatibleBaseModel');
}
}
if (!entity.ipAdapter.image) {
// No image selected
warnings.push('parameters.invoke.layer.ipAdapterNoImageSelected');
}
return warnings;
};
export const getControlLayerWarnings = (entity: CanvasControlLayerState, model: ParameterModel | null): string[] => {
const warnings: string[] = [];
if (entity.objects.length === 0) {
// Layer is in empty state - skip other checks
warnings.push('parameters.invoke.layer.emptyLayer');
} else {
if (!entity.controlAdapter.model) {
// No model selected
warnings.push('parameters.invoke.layer.controlAdapterNoModelSelected');
} else if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push('parameters.invoke.layer.unsupportedModel');
} else if (entity.controlAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push('parameters.invoke.layer.controlAdapterIncompatibleBaseModel');
}
}
}
return warnings;
};
export const getRasterLayerWarnings = (entity: CanvasRasterLayerState, _model: ParameterModel | null): string[] => {
const warnings: string[] = [];
if (entity.objects.length === 0) {
// Layer is in empty state - skip other checks
warnings.push('parameters.invoke.layer.emptyLayer');
}
return warnings;
};
export const getInpaintMaskWarnings = (entity: CanvasInpaintMaskState, _model: ParameterModel | null): string[] => {
const warnings: string[] = [];
if (entity.objects.length === 0) {
// Layer is in empty state - skip other checks
warnings.push('parameters.invoke.layer.emptyLayer');
}
return warnings;
};
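A minimal usage sketch of these validators, e.g. for a layer-header warning tooltip. It assumes the layer state types are discriminated by their type field; the translate callback is injected so the sketch stays independent of the app's i18n setup, and the helper name is illustrative.
import type {
  CanvasControlLayerState,
  CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import {
  getControlLayerWarnings,
  getRegionalGuidanceWarnings,
} from 'features/controlLayers/store/validators';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
// Hypothetical helper: resolves a layer's warning keys into a human-readable tooltip string.
const getLayerWarningTooltip = (
  entity: CanvasControlLayerState | CanvasRegionalGuidanceState,
  model: ParameterModel | null,
  translate: (key: string) => string
): string | null => {
  const warnings =
    entity.type === 'control_layer'
      ? getControlLayerWarnings(entity, model)
      : getRegionalGuidanceWarnings(entity, model);
  return warnings.length > 0 ? warnings.map(translate).join(', ') : null;
};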


@@ -1,35 +1,41 @@
import { logger } from 'app/logging/logger';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type {
CanvasControlLayerState,
ControlNetConfig,
Rect,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types';
import { getControlLayerWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import type { ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
type AddControlNetsArg = {
manager: CanvasManager;
entities: CanvasControlLayerState[];
g: Graph;
rect: Rect;
collector: Invocation<'collect'>;
model: ParameterModel;
};
type AddControlNetsResult = {
addedControlNets: number;
};
export const addControlNets = async (
manager: CanvasManager,
layers: CanvasControlLayerState[],
g: Graph,
rect: Rect,
collector: Invocation<'collect'>,
base: BaseModelType
): Promise<AddControlNetsResult> => {
const validControlLayers = layers
.filter((layer) => layer.isEnabled)
.filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
.filter((layer) => layer.controlAdapter.type === 'controlnet');
export const addControlNets = async ({
manager,
entities,
g,
rect,
collector,
model,
}: AddControlNetsArg): Promise<AddControlNetsResult> => {
const validControlLayers = entities
.filter((entity) => entity.isEnabled)
.filter((entity) => entity.controlAdapter.type === 'controlnet')
.filter((entity) => getControlLayerWarnings(entity, model).length === 0);
const result: AddControlNetsResult = {
addedControlNets: 0,
@@ -54,22 +60,31 @@ export const addControlNets = async (
return result;
};
type AddT2IAdaptersArg = {
manager: CanvasManager;
entities: CanvasControlLayerState[];
g: Graph;
rect: Rect;
collector: Invocation<'collect'>;
model: ParameterModel;
};
type AddT2IAdaptersResult = {
addedT2IAdapters: number;
};
export const addT2IAdapters = async (
manager: CanvasManager,
layers: CanvasControlLayerState[],
g: Graph,
rect: Rect,
collector: Invocation<'collect'>,
base: BaseModelType
): Promise<AddT2IAdaptersResult> => {
const validControlLayers = layers
.filter((layer) => layer.isEnabled)
.filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
.filter((layer) => layer.controlAdapter.type === 't2i_adapter');
export const addT2IAdapters = async ({
manager,
entities,
g,
rect,
collector,
model,
}: AddT2IAdaptersArg): Promise<AddT2IAdaptersResult> => {
const validControlLayers = entities
.filter((entity) => entity.isEnabled)
.filter((entity) => entity.controlAdapter.type === 't2i_adapter')
.filter((entity) => getControlLayerWarnings(entity, model).length === 0);
const result: AddT2IAdaptersResult = {
addedT2IAdapters: 0,
@@ -145,11 +160,3 @@ const addT2IAdapterToGraph = (
g.addEdge(t2iAdapter, 't2i_adapter', collector, 'item');
};
const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => {
// Must have a model
const hasModel = Boolean(controlAdapter.model);
// Model must match the current base model
const modelMatchesBase = controlAdapter.model?.base === base;
return hasModel && modelMatchesBase;
};


@@ -1,19 +1,23 @@
import type { CanvasReferenceImageState } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, Invocation } from 'services/api/types';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
type AddIPAdaptersResult = {
addedIPAdapters: number;
};
export const addIPAdapters = (
ipAdapters: CanvasReferenceImageState[],
g: Graph,
collector: Invocation<'collect'>,
base: BaseModelType
): AddIPAdaptersResult => {
const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity, base));
type AddIPAdaptersArg = {
entities: CanvasReferenceImageState[];
g: Graph;
collector: Invocation<'collect'>;
model: ParameterModel;
};
export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => {
const validIPAdapters = entities.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
const result: AddIPAdaptersResult = {
addedIPAdapters: 0,
@@ -76,11 +80,3 @@ const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: In
g.addEdge(ipAdapterNode, 'ip_adapter', collector, 'item');
};
const isValidIPAdapter = ({ isEnabled, ipAdapter }: CanvasReferenceImageState, base: BaseModelType): boolean => {
// Must have a model that matches the current base and must have a control image
const hasModel = Boolean(ipAdapter.model);
const modelMatchesBase = ipAdapter.model?.base === base;
const hasImage = Boolean(ipAdapter.image);
return isEnabled && hasModel && modelMatchesBase && hasImage;
};


@@ -3,15 +3,12 @@ import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
CanvasRegionalGuidanceState,
IPAdapterConfig,
Rect,
RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import type { CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
import type { BaseModelType, Invocation } from 'services/api/types';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
@@ -23,19 +20,26 @@ type AddedRegionResult = {
addedIPAdapters: number;
};
const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => {
const isEnabled = rg.isEnabled;
const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt);
const hasIPAdapter = rg.referenceImages.filter(({ ipAdapter }) => isValidIPAdapter(ipAdapter, base)).length > 0;
return isEnabled && (hasTextPrompt || hasIPAdapter);
type AddRegionsArg = {
manager: CanvasManager;
regions: CanvasRegionalGuidanceState[];
g: Graph;
bbox: Rect;
model: ParameterModel;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
posCondCollect: Invocation<'collect'>;
negCondCollect: Invocation<'collect'> | null;
ipAdapterCollect: Invocation<'collect'>;
};
/**
* Adds regional guidance to the graph
* @param manager The canvas manager
* @param regions Array of regions to add
* @param g The graph to add the layers to
* @param base The base model type
* @param denoise The main denoise node
* @param bbox The bounding box
* @param model The main model
* @param posCond The positive conditioning node
* @param negCond The negative conditioning node
* @param posCondCollect The positive conditioning collector
@@ -44,22 +48,28 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) =>
* @returns A promise that resolves to the regions that were successfully added to the graph
*/
export const addRegions = async (
manager: CanvasManager,
regions: CanvasRegionalGuidanceState[],
g: Graph,
bbox: Rect,
base: BaseModelType,
denoise: Invocation<'denoise_latents'>,
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
posCondCollect: Invocation<'collect'>,
negCondCollect: Invocation<'collect'>,
ipAdapterCollect: Invocation<'collect'>
): Promise<AddedRegionResult[]> => {
const isSDXL = base === 'sdxl';
export const addRegions = async ({
manager,
regions,
g,
bbox,
model,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollect,
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
const isSDXL = model.base === 'sdxl';
const isFLUX = model.base === 'flux';
const validRegions = regions.filter((rg) => {
if (!rg.isEnabled) {
return false;
}
return getRegionalGuidanceWarnings(rg, model).length === 0;
});
const validRegions = regions.filter((rg) => isValidRegion(rg, base));
const results: AddedRegionResult[] = [];
for (const region of validRegions) {
@@ -94,20 +104,27 @@ export const addRegions = async (
if (region.positivePrompt) {
// The main positive conditioning node
result.addedPositivePrompt = true;
const regionalPosCond = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
}
: {
type: 'compel',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
}
);
let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
if (isSDXL) {
regionalPosCond = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
});
} else if (isFLUX) {
regionalPosCond = g.addNode({
type: 'flux_text_encoder',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
});
} else {
regionalPosCond = g.addNode({
type: 'compel',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
});
}
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask');
// Connect the conditioning to the collector
@@ -115,38 +132,55 @@ export const addRegions = async (
// Copy the connections to the "global" positive conditioning node to the regional cond
if (posCond.type === 'compel') {
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
// Clone the edge, but change the destination node to the regional conditioning node
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
} else if (posCond.type === 'sdxl_compel_prompt') {
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
} else if (posCond.type === 'flux_text_encoder') {
for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
} else {
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
// Clone the edge, but change the destination node to the regional conditioning node
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
assert(false, 'Unsupported positive conditioning node type.');
}
}
if (region.negativePrompt) {
result.addedNegativePrompt = true;
assert(negCond, 'Negative conditioning node is required if there is a negative prompt');
assert(negCondCollect, 'Negative conditioning collector is required if there is a negative prompt');
// The main negative conditioning node
const regionalNegCond = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
style: region.negativePrompt,
}
: {
type: 'compel',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
}
);
result.addedNegativePrompt = true;
let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
if (isSDXL) {
regionalNegCond = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
style: region.negativePrompt,
});
} else if (isFLUX) {
regionalNegCond = g.addNode({
type: 'flux_text_encoder',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
});
} else {
regionalNegCond = g.addNode({
type: 'compel',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
});
}
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask');
// Connect the conditioning to the collector
@@ -158,17 +192,27 @@ export const addRegions = async (
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
} else {
} else if (negCond.type === 'sdxl_compel_prompt') {
for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
} else if (negCond.type === 'flux_text_encoder') {
for (const edge of g.getEdgesTo(negCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
} else {
assert(false, 'Unsupported negative conditioning node type.');
}
}
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (region.autoNegative && region.positivePrompt) {
assert(negCondCollect, 'Negative conditioning collector is required if there is an auto-negative setting');
result.addedAutoNegativePositivePrompt = true;
// We re-use the mask image, but invert it when converting to tensor
const invertTensorMask = g.addNode({
@@ -178,20 +222,27 @@ export const addRegions = async (
// Connect the OG mask image to the inverted mask-to-tensor node
g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
const regionalPosCondInverted = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
style: region.positivePrompt,
}
: {
type: 'compel',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
}
);
let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
if (isSDXL) {
regionalPosCondInverted = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
style: region.positivePrompt,
});
} else if (isFLUX) {
regionalPosCondInverted = g.addNode({
type: 'flux_text_encoder',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
});
} else {
regionalPosCondInverted = g.addNode({
type: 'compel',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
});
}
// Connect the inverted mask to the conditioning
g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask');
// Connect the conditioning to the negative collector
@@ -203,20 +254,26 @@ export const addRegions = async (
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
} else {
} else if (posCond.type === 'sdxl_compel_prompt') {
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
} else if (posCond.type === 'flux_text_encoder') {
for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
} else {
assert(false, 'Unsupported positive conditioning node type.');
}
}
const validRGIPAdapters: RegionalGuidanceReferenceImageState[] = region.referenceImages.filter(({ ipAdapter }) =>
isValidIPAdapter(ipAdapter, base)
);
for (const { id, ipAdapter } of region.referenceImages) {
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
for (const { id, ipAdapter } of validRGIPAdapters) {
result.addedIPAdapters++;
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
@@ -248,11 +305,3 @@ export const addRegions = async (
return results;
};
const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {
// Must have a model that matches the current base and must have a control image
const hasModel = Boolean(ipAdapter.model);
const modelMatchesBase = ipAdapter.model?.base === base;
const hasImage = Boolean(ipAdapter.image);
return hasModel && modelMatchesBase && hasImage;
};


@@ -11,6 +11,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -79,7 +80,10 @@ export const buildFLUXGraph = async (
id: getPrefixedId('flux_text_encoder'),
prompt: positivePrompt,
});
const posCondCollect = g.addNode({
type: 'collect',
id: getPrefixedId('pos_cond_collect'),
});
const denoise = g.addNode({
type: 'flux_denoise',
id: getPrefixedId('flux_denoise'),
@@ -104,13 +108,12 @@ export const buildFLUXGraph = async (
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning');
g.addEdge(denoise, 'latents', l2i, 'latents');
addFLUXLoRAs(state, g, denoise, modelLoader, posCond);
g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning');
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'flux');
@@ -196,31 +199,50 @@ export const buildFLUXGraph = async (
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
const controlNetResult = await addControlNets({
manager,
canvas.controlLayers.entities,
entities: canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
rect: canvas.bbox.rect,
collector: controlNetCollector,
model: modelConfig,
});
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
} else {
g.deleteNode(controlNetCollector.id);
}
const ipAdapterCollector = g.addNode({
const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
const ipAdapterResult = addIPAdapters({
entities: canvas.referenceImages.entities,
g,
collector: ipAdapterCollect,
model: modelConfig,
});
const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
const regionsResult = await addRegions({
manager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
model: modelConfig,
posCond,
negCond: null,
posCondCollect,
negCondCollect: null,
ipAdapterCollect,
});
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
g.deleteNode(ipAdapterCollector.id);
g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {


@@ -227,14 +227,14 @@ export const buildSD1Graph = async (
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
const controlNetResult = await addControlNets({
manager,
canvas.controlLayers.entities,
entities: canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
rect: canvas.bbox.rect,
collector: controlNetCollector,
model: modelConfig,
});
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
} else {
@@ -245,46 +245,50 @@ export const buildSD1Graph = async (
type: 'collect',
id: getPrefixedId('t2i_adapter_collector'),
});
const t2iAdapterResult = await addT2IAdapters(
const t2iAdapterResult = await addT2IAdapters({
manager,
canvas.controlLayers.entities,
entities: canvas.controlLayers.entities,
g,
canvas.bbox.rect,
t2iAdapterCollector,
modelConfig.base
);
rect: canvas.bbox.rect,
collector: t2iAdapterCollector,
model: modelConfig,
});
if (t2iAdapterResult.addedT2IAdapters > 0) {
g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
} else {
g.deleteNode(t2iAdapterCollector.id);
}
const ipAdapterCollector = g.addNode({
const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
const regionsResult = await addRegions(
manager,
canvas.regionalGuidance.entities,
const ipAdapterResult = addIPAdapters({
entities: canvas.referenceImages.entities,
g,
canvas.bbox.rect,
modelConfig.base,
denoise,
collector: ipAdapterCollect,
model: modelConfig,
});
const regionsResult = await addRegions({
manager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
model: modelConfig,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollector
);
ipAdapterCollect,
});
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
g.deleteNode(ipAdapterCollector.id);
g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {


@@ -232,14 +232,14 @@ export const buildSDXLGraph = async (
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
const controlNetResult = await addControlNets({
manager,
canvas.controlLayers.entities,
entities: canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
rect: canvas.bbox.rect,
collector: controlNetCollector,
model: modelConfig,
});
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
} else {
@@ -250,46 +250,50 @@ export const buildSDXLGraph = async (
type: 'collect',
id: getPrefixedId('t2i_adapter_collector'),
});
const t2iAdapterResult = await addT2IAdapters(
const t2iAdapterResult = await addT2IAdapters({
manager,
canvas.controlLayers.entities,
entities: canvas.controlLayers.entities,
g,
canvas.bbox.rect,
t2iAdapterCollector,
modelConfig.base
);
rect: canvas.bbox.rect,
collector: t2iAdapterCollector,
model: modelConfig,
});
if (t2iAdapterResult.addedT2IAdapters > 0) {
g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
} else {
g.deleteNode(t2iAdapterCollector.id);
}
const ipAdapterCollector = g.addNode({
const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
const regionsResult = await addRegions(
manager,
canvas.regionalGuidance.entities,
const ipAdapterResult = addIPAdapters({
entities: canvas.referenceImages.entities,
g,
canvas.bbox.rect,
modelConfig.base,
denoise,
collector: ipAdapterCollect,
model: modelConfig,
});
const regionsResult = await addRegions({
manager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
model: modelConfig,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollector
);
ipAdapterCollect,
});
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
g.deleteNode(ipAdapterCollector.id);
g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {


@@ -4,6 +4,13 @@ import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasState } from 'features/controlLayers/store/types';
import {
getControlLayerWarnings,
getGlobalReferenceImageWarnings,
getInpaintMaskWarnings,
getRasterLayerWarnings,
getRegionalGuidanceWarnings,
} from 'features/controlLayers/store/validators';
import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
@@ -278,17 +285,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
// Must have model
if (!controlLayer.controlAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
}
// Model base must match
if (controlLayer.controlAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
}
const problems = getControlLayerWarnings(controlLayer, model);
if (problems.length) {
const content = upperFirst(problems.join(', '));
const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
reasons.push({ prefix, content });
}
});
@@ -300,23 +300,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
// Must have model
if (!entity.ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (entity.ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!entity.ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
const problems = getGlobalReferenceImageWarnings(entity, model);
if (problems.length) {
const content = upperFirst(problems.join(', '));
const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
reasons.push({ prefix, content });
}
});
@@ -328,32 +315,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
// Must have a region
if (entity.objects.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
}
// Must have at least 1 prompt or IP Adapter
if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
}
entity.referenceImages.forEach(({ ipAdapter }) => {
// Must have model
if (!ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
});
const problems = getRegionalGuidanceWarnings(entity, model);
if (problems.length) {
const content = upperFirst(problems.join(', '));
const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
reasons.push({ prefix, content });
}
});
@@ -365,10 +330,25 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
const problems = getRasterLayerWarnings(entity, model);
if (problems.length) {
const content = upperFirst(problems.join(', '));
const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
reasons.push({ prefix, content });
}
});
canvas.inpaintMasks.entities
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layer_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems = getInpaintMaskWarnings(entity, model);
if (problems.length) {
const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
reasons.push({ prefix, content });
}
});


@@ -6564,6 +6564,11 @@ export type components = {
* @description The name of conditioning tensor
*/
conditioning_name: string;
/**
* @description The mask associated with this conditioning tensor. Excluded regions should be set to False, included regions should be set to True.
* @default null
*/
mask?: components["schemas"]["TensorField"] | null;
};
/**
* FluxConditioningOutput
@@ -6771,15 +6776,17 @@ export type components = {
*/
transformer?: components["schemas"]["TransformerField"];
/**
* Positive Text Conditioning
* @description Positive conditioning tensor
* @default null
*/
positive_text_conditioning?: components["schemas"]["FluxConditioningField"];
positive_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][];
/**
* Negative Text Conditioning
* @description Negative conditioning tensor. Can be None if cfg_scale is 1.0.
* @default null
*/
negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | null;
negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][] | null;
/**
* CFG Scale
* @description Classifier-Free Guidance scale
@@ -7133,6 +7140,11 @@ export type components = {
* @default null
*/
prompt?: string;
/**
* @description A mask defining the region that this conditioning prompt applies to.
* @default null
*/
mask?: components["schemas"]["TensorField"] | null;
/**
* type
* @default flux_text_encoder

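A hedged illustration of what the widened denoise schema permits: positive_text_conditioning can now be a list of conditioning fields, each optionally carrying a region mask. The field names come from the schema above; the import path, the conditioning names, and the tensor_name shape of TensorField are assumptions made for the sake of the sketch.
import type { components } from 'services/api/schema';
type FluxConditioningField = components['schemas']['FluxConditioningField'];
// Global prompt plus one masked regional prompt, as the FLUX denoise node can now accept.
const positiveTextConditioning: FluxConditioningField[] = [
  { conditioning_name: 'global_positive_cond' },
  {
    conditioning_name: 'region_1_positive_cond',
    mask: { tensor_name: 'region_1_mask' }, // assumed TensorField shape
  },
];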

@@ -56,7 +56,7 @@ dependencies = [
"torchmetrics",
"torchsde",
"torchvision",
"transformers==4.41.1",
"transformers==4.46.3",
# Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.1",