Compare commits

...

14 Commits

Author SHA1 Message Date
Ryan Dick
32f602ab2a Get naive latent space regional prompting working. Next, need to support areas in addition to masks. 2024-02-20 19:02:19 -05:00
Ryan Dick
cb6c5c23ce Fix AddConditioningMaskInvocation logic. 2024-02-20 18:28:59 -05:00
Ryan Dick
d74045d78e Create a new TextConditioningInfoWithMask type for passing conditioning info around. 2024-02-20 15:14:36 -05:00
Ryan Dick
4efd0f7fa9 Add mask_strength to ConditioningField. 2024-02-20 11:05:30 -05:00
Ryan Dick
43d5803927 Add RectangleMaskInvocation. 2024-02-20 10:36:35 -05:00
Ryan Dick
ef51005881 Remove unused code for attention map saving. 2024-02-15 17:28:55 -05:00
Ryan Dick
7b0326d7f7 Delete unused functions from shared_invokeai_diffusion.py. 2024-02-15 17:22:37 -05:00
Ryan Dick
f590b39f88 Add support for a list of ConditioningFields in DenoiseLatents. 2024-02-15 14:41:54 -05:00
Ryan Dick
58277c6ada Add a mask to the ConditioningField primitive type. 2024-02-15 13:53:32 -05:00
Ryan Dick
382fa57f3b Remove unused constructor declared with typo in name: __int__. 2024-02-14 18:18:58 -05:00
Ryan Dick
ee3abc171d Merge sequential conditioning and cac conditioning logic to eliminate a bunch of duplication. 2024-02-14 18:17:46 -05:00
Ryan Dick
bf72cee555 Remove outdated comments related to T2I-Adapters and ControlNets. 2024-02-14 17:37:40 -05:00
Ryan Dick
e866e3b19f Remove use of **kwargs in do_unet_step(...), where full parameter list is known and supported. 2024-02-14 17:37:32 -05:00
Ryan Dick
16e574825c Fix avoid storing extra conditioning info in two places. 2024-02-14 15:34:15 -05:00
10 changed files with 351 additions and 863 deletions

View File

@@ -0,0 +1,92 @@
import numpy as np
import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
InputField,
InvocationContext,
WithMetadata,
invocation,
)
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
@invocation(
"add_conditioning_mask",
title="Add Conditioning Mask",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
)
class AddConditioningMaskInvocation(BaseInvocation):
"""Add a mask to an existing conditioning tensor."""
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
mask: ImageField = InputField(
description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. "
"Pixels <128 are excluded from the mask, pixels >=128 are included in the mask."
)
mask_strength: float = InputField(
description="The strength of the mask to apply to the conditioning tensor.", default=1.0
)
@staticmethod
def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
"""Convert a PIL image to a uint8 mask tensor."""
np_image = np.array(image)
torch_image = torch.from_numpy(np_image[:, :, 0])
mask = torch_image >= 128
return mask.to(dtype=torch.uint8)
def invoke(self, context: InvocationContext) -> ConditioningOutput:
image = context.services.images.get_pil_image(self.mask.image_name)
mask = self.convert_image_to_mask(image)
mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask"
context.services.latents.save(mask_name, mask)
self.conditioning.mask_name = mask_name
self.conditioning.mask_strength = self.mask_strength
return ConditioningOutput(conditioning=self.conditioning)
@invocation(
"rectangle_mask",
title="Create Rectangle Mask",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
)
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
"""Create a mask image containing a rectangular mask region."""
height: int = InputField(description="The height of the image.")
width: int = InputField(description="The width of the image.")
y_top: int = InputField(description="The top y-coordinate of the rectangle (inclusive).")
y_bottom: int = InputField(description="The bottom y-coordinate of the rectangle (exclusive).")
x_left: int = InputField(description="The left x-coordinate of the rectangle (inclusive).")
x_right: int = InputField(description="The right x-coordinate of the rectangle (exclusive).")
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = np.zeros((self.height, self.width, 3), dtype=np.uint8)
mask[self.y_top : self.y_bottom, self.x_left : self.x_right, :] = 255
mask_image = Image.fromarray(mask)
image_dto = context.services.images.create(
image=mask_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -40,7 +40,11 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
IPAdapterConditioningInfo,
TextConditioningInfoWithMask,
)
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import BaseModelType
@@ -226,7 +230,7 @@ def get_scheduler(
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
positive_conditioning: ConditioningField = InputField(
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
)
negative_conditioning: ConditioningField = InputField(
@@ -330,19 +334,34 @@ class DenoiseLatentsInvocation(BaseInvocation):
unet,
seed,
) -> ConditioningData:
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = c.extra_conditioning
# self.positive_conditioning could be a list or a single ConditioningField. Normalize to a list here.
positive_conditioning_list = self.positive_conditioning
if not isinstance(positive_conditioning_list, list):
positive_conditioning_list = [positive_conditioning_list]
text_embeddings: list[TextConditioningInfoWithMask] = []
for positive_conditioning in positive_conditioning_list:
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
mask_name = positive_conditioning.mask_name
mask = None
if mask_name is not None:
mask = context.services.latents.get(mask_name)
text_embeddings.append(
TextConditioningInfoWithMask(
text_conditioning_info=positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype),
mask=mask,
mask_strength=positive_conditioning.mask_strength,
)
)
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
text_embeddings=c,
text_embeddings=text_embeddings,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold,
warmup=0.2, # warmup,
@@ -767,10 +786,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)
(
result_latents,
result_attention_map_saver,
) = pipeline.latents_from_embeddings(
result_latents = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,

View File

@@ -428,6 +428,16 @@ class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
mask_name: Optional[str] = Field(
default=None,
description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
"regions should be set to 1.",
)
mask_strength: float = Field(
default=1.0,
description="The strength of the mask. Only has an effect if mask_name is set. The strength is relative to "
"other masks. The default is 1.0. If set to 0.0, this mask will be ignored.",
)
@invocation_output("conditioning_output")

View File

@@ -3,4 +3,3 @@ Initialization file for the invokeai.backend.stable_diffusion package
"""
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401

View File

@@ -12,7 +12,6 @@ import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -26,9 +25,9 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..util import auto_detect_slice_size, normalize_device
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
@dataclass
@@ -39,7 +38,6 @@ class PipelineIntermediateState:
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None
attention_map_saver: Optional[AttentionMapSaver] = None
@dataclass
@@ -184,19 +182,6 @@ class T2IAdapterData:
end_step_percent: float = Field(default=1.0)
@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
r"""
Output class for InvokeAI's Stable Diffusion pipeline.
Args:
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
after generation completes. Optional.
"""
attention_map_saver: Optional[AttentionMapSaver]
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -336,9 +321,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
) -> torch.Tensor:
if init_timestep.shape[0] == 0:
return latents, None
return latents
if additional_guidance is None:
additional_guidance = []
@@ -378,7 +363,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
try:
latents, attention_map_saver = self.generate_latents_from_embeddings(
latents = self.generate_latents_from_embeddings(
latents,
timesteps,
conditioning_data,
@@ -395,7 +380,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if mask is not None:
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
return latents, attention_map_saver
return latents
def generate_latents_from_embeddings(
self,
@@ -408,26 +393,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
):
) -> torch.Tensor:
self._adjust_memory_efficient_attention(latents)
if additional_guidance is None:
additional_guidance = []
batch_size = latents.shape[0]
attention_map_saver: Optional[AttentionMapSaver] = None
if timesteps.shape[0] == 0:
return latents, attention_map_saver
return latents
extra_conditioning_info = conditioning_data.text_embeddings[0].text_conditioning_info.extra_conditioning
use_cross_attention_control = (
extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
)
use_ip_adapter = ip_adapter_data is not None
if sum([use_cross_attention_control, use_ip_adapter]) > 1:
raise Exception("Cross-attention control and IP-Adapter cannot be used simultaneously (yet).")
ip_adapter_unet_patcher = None
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
if use_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=conditioning_data.extra,
step_count=len(self.scheduler.timesteps),
extra_conditioning_info=extra_conditioning_info,
)
self.use_ip_adapter = False
elif ip_adapter_data is not None:
elif use_ip_adapter:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
@@ -475,13 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
predicted_original = getattr(step_output, "pred_original_sample", None)
# TODO resuscitate attention map saving
# if i == len(timesteps)-1 and extra_conditioning_info is not None:
# eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
# attention_map_token_ids = range(1, eos_token_index)
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
if callback is not None:
callback(
PipelineIntermediateState(
@@ -491,11 +475,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
attention_map_saver=attention_map_saver,
)
)
return latents, attention_map_saver
return latents
@torch.inference_mode()
def step(
@@ -537,15 +520,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
ip_adapter_unet_patcher.set_scale(i, 0.0)
# Handle ControlNet(s) and T2I-Adapter(s)
# Handle ControlNet(s)
down_block_additional_residuals = None
mid_block_additional_residual = None
down_intrablock_additional_residuals = None
# if control_data is not None and t2i_adapter_data is not None:
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
# raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
# elif control_data is not None:
if control_data is not None:
down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
control_data=control_data,
@@ -555,7 +532,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count,
conditioning_data=conditioning_data,
)
# elif t2i_adapter_data is not None:
# Handle T2I-Adapter(s)
down_intrablock_additional_residuals = None
if t2i_adapter_data is not None:
accum_adapter_state = None
for single_t2i_adapter_data in t2i_adapter_data:
@@ -581,7 +560,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
accum_adapter_state[idx] += value * t2i_adapter_weight
# down_block_additional_residuals = accum_adapter_state
down_intrablock_additional_residuals = accum_adapter_state
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
@@ -590,7 +568,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=step_index,
total_step_count=total_step_count,
conditioning_data=conditioning_data,
# extra:
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter

View File

@@ -1,6 +1 @@
"""
Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401

View File

@@ -21,11 +21,7 @@ class ExtraConditioningInfo:
@dataclass
class BasicConditioningInfo:
embeds: torch.Tensor
# TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
# should only be stored in one place.
extra_conditioning: Optional[ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
@@ -43,6 +39,18 @@ class SDXLConditioningInfo(BasicConditioningInfo):
return super().to(device=device, dtype=dtype)
class TextConditioningInfoWithMask:
def __init__(
self,
text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo],
mask: Optional[torch.Tensor],
mask_strength: float,
):
self.text_conditioning_info = text_conditioning_info
self.mask = mask
self.mask_strength = mask_strength
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
@@ -65,8 +73,8 @@ class IPAdapterConditioningInfo:
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
unconditioned_embeddings: Union[BasicConditioningInfo, SDXLConditioningInfo]
text_embeddings: list[TextConditioningInfoWithMask]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
@@ -78,7 +86,6 @@ class ConditioningData:
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
"""
guidance_rescale_multiplier: float = 0
extra: Optional[ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
@@ -87,10 +94,6 @@ class ConditioningData:
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
@property
def dtype(self):
return self.text_embeddings.dtype
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
scheduler_args = dict(self.scheduler_args)
step_method = inspect.signature(scheduler.step)

View File

@@ -3,19 +3,13 @@
import enum
import math
from dataclasses import dataclass, field
from typing import Callable, Optional
from typing import Optional
import diffusers
import psutil
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, SlicedAttnProcessor
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from torch import nn
import invokeai.backend.util.logging as logger
from ...util import torch_dtype
@@ -25,72 +19,14 @@ class CrossAttentionType(enum.Enum):
TOKENS = 2
class Context:
cross_attention_mask: Optional[torch.Tensor]
cross_attention_index_map: Optional[torch.Tensor]
class Action(enum.Enum):
NONE = 0
SAVE = (1,)
APPLY = 2
def __init__(self, arguments: Arguments, step_count: int):
class CrossAttnControlContext:
def __init__(self, arguments: Arguments):
"""
:param arguments: Arguments for the cross-attention control process
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
"""
self.cross_attention_mask = None
self.cross_attention_index_map = None
self.self_cross_attention_action = Context.Action.NONE
self.tokens_cross_attention_action = Context.Action.NONE
self.cross_attention_mask: Optional[torch.Tensor] = None
self.cross_attention_index_map: Optional[torch.Tensor] = None
self.arguments = arguments
self.step_count = step_count
self.self_cross_attention_module_identifiers = []
self.tokens_cross_attention_module_identifiers = []
self.saved_cross_attention_maps = {}
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
raise AssertionError(f"name {name} cannot appear more than once")
self.self_cross_attention_module_identifiers.append(name)
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
raise AssertionError(f"name {name} cannot appear more than once")
self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.SAVE
else:
self.tokens_cross_attention_action = Context.Action.SAVE
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.APPLY
else:
self.tokens_cross_attention_action = Context.Action.APPLY
def is_tokens_cross_attention(self, module_identifier) -> bool:
return module_identifier in self.tokens_cross_attention_module_identifiers
def get_should_save_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == Context.Action.SAVE
elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == Context.Action.SAVE
return False
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == Context.Action.APPLY
elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == Context.Action.APPLY
return False
def get_active_cross_attention_control_types_for_step(
self, percent_through: float = None
@@ -111,219 +47,8 @@ class Context:
to_control.append(CrossAttentionType.TOKENS)
return to_control
def save_slice(
self,
identifier: str,
slice: torch.Tensor,
dim: Optional[int],
offset: int,
slice_size: Optional[int],
):
if identifier not in self.saved_cross_attention_maps:
self.saved_cross_attention_maps[identifier] = {
"dim": dim,
"slice_size": slice_size,
"slices": {offset or 0: slice},
}
else:
self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice
def get_slice(
self,
identifier: str,
requested_dim: Optional[int],
requested_offset: int,
slice_size: int,
):
saved_attention_dict = self.saved_cross_attention_maps[identifier]
if requested_dim is None:
if saved_attention_dict["dim"] is not None:
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
return saved_attention_dict["slices"][0]
if saved_attention_dict["dim"] == requested_dim:
if slice_size != saved_attention_dict["slice_size"]:
raise RuntimeError(
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
)
return saved_attention_dict["slices"][requested_offset]
if saved_attention_dict["dim"] is None:
whole_saved_attention = saved_attention_dict["slices"][0]
if requested_dim == 0:
return whole_saved_attention[requested_offset : requested_offset + slice_size]
elif requested_dim == 1:
return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
if saved_attention is None:
return None, None
return saved_attention["dim"], saved_attention["slice_size"]
def clear_requests(self, cleanup=True):
self.tokens_cross_attention_action = Context.Action.NONE
self.self_cross_attention_action = Context.Action.NONE
if cleanup:
self.saved_cross_attention_maps = {}
def offload_saved_attention_slices_to_cpu(self):
for _key, map_dict in self.saved_cross_attention_maps.items():
for offset, slice in map_dict["slices"].items():
map_dict[offset] = slice.to("cpu")
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
and dymamic slicing strategy selection.
"""
def __init__(self):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
self.attention_slice_calculated_callback = None
def set_attention_slice_wrangler(
self,
wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
):
"""
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current Attention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
"""
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
self.slicing_strategy_getter = getter
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
self.attention_slice_calculated_callback = callback
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
# calculate attention scores
# attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
attention_scores = torch.baddbmm(
torch.empty(
query.shape[0],
query.shape[1],
key.shape[1],
dtype=query.dtype,
device=query.device,
),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
if self.attention_slice_calculated_callback is not None:
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
hidden_states = torch.bmm(attention_slice, value)
return hidden_states
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = get_mem_free_total(q.device)
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_invokeai_attention_mem_efficient(self, q, k, v):
if q.device.type == "cuda":
# print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == "mps" or q.device.type == "cpu":
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
def restore_default_cross_attention(
model,
is_running_diffusers: bool,
restore_attention_processor: Optional[AttentionProcessor] = None,
):
if is_running_diffusers:
unet = model
unet.set_attn_processor(restore_attention_processor or AttnProcessor())
else:
remove_attention_function(model)
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -362,170 +87,6 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
cross_attention_class: type = InvokeAIDiffusersCrossAttention
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
attention_module_tuples = [
(name, module)
for name, module in model.named_modules()
if isinstance(module, cross_attention_class) and which_attn in name
]
cross_attention_modules_in_model_count = len(attention_module_tuples)
expected_count = 16
if cross_attention_modules_in_model_count != expected_count:
# non-fatal error but .swap() won't work.
logger.error(
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
"failed or some assumption has changed about the structure of the model itself. Please fix the "
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
"attention map display will not work properly until it is fixed."
)
return attention_module_tuples
def inject_attention_function(unet, context: Context):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
# memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
attention_slice = suggested_attention_slice
if context.get_should_save_maps(module.identifier):
# print(module.identifier, "saving suggested_attention_slice of shape",
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
context.save_slice(
module.identifier,
slice_to_save,
dim=dim,
offset=offset,
slice_size=slice_size,
)
elif context.get_should_apply_saved_maps(module.identifier):
# print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
# slice may have been offloaded to CPU
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
if context.is_tokens_cross_attention(module.identifier):
index_map = context.cross_attention_index_map
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
this_attention_slice = suggested_attention_slice
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
saved_mask = mask
this_mask = 1 - mask
attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
else:
# just use everything
attention_slice = saved_attention_slice
return attention_slice
cross_attention_modules = get_cross_attention_modules(
unet, CrossAttentionType.TOKENS
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
else:
raise
def remove_attention_function(unet):
cross_attention_modules = get_cross_attention_modules(
unet, CrossAttentionType.TOKENS
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for _identifier, module in cross_attention_modules:
try:
# clear wrangler callback
module.set_attention_slice_wrangler(None)
module.set_slicing_strategy_getter(None)
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
else:
raise
def is_attribute_error_about(error: AttributeError, attribute: str):
if hasattr(error, "name"): # Python 3.10
return error.name == attribute
else: # Python 3.9
return attribute in str(error)
def get_mem_free_total(device):
# only on cuda
if not torch.cuda.is_available():
return None
stats = torch.cuda.memory_stats(device)
mem_active = stats["active_bytes.all.current"]
mem_reserved = stats["reserved_bytes.all.current"]
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
InvokeAICrossAttentionMixin.__init__(self)
def _attention(self, query, key, value, attention_mask=None):
# default_result = super()._attention(query, key, value)
if attention_mask is not None:
print(f"{type(self).__name__} ignoring passed-in attention_mask")
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
return hidden_states
## 🧨diffusers implementation follows
"""
# base implementation
class AttnProcessor:
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
"""
@dataclass
class SwapCrossAttnContext:
modified_text_embeddings: torch.Tensor
@@ -533,18 +94,6 @@ class SwapCrossAttnContext:
mask: torch.Tensor # in the target space of the index_map
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
def __int__(
self,
cac_types_to_do: [CrossAttentionType],
modified_text_embeddings: torch.Tensor,
index_map: torch.Tensor,
mask: torch.Tensor,
):
self.cross_attention_types_to_do = cac_types_to_do
self.modified_text_embeddings = modified_text_embeddings
self.index_map = index_map
self.mask = mask
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
return attn_type in self.cross_attention_types_to_do

View File

@@ -1,100 +0,0 @@
import math
from typing import Optional
import torch
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional import resize as tv_resize
class AttentionMapSaver:
def __init__(self, token_ids: range, latents_shape: torch.Size):
self.token_ids = token_ids
self.latents_shape = latents_shape
# self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
self.collated_maps: dict[str, torch.Tensor] = {}
def clear_maps(self):
self.collated_maps = {}
def add_attention_maps(self, maps: torch.Tensor, key: str):
"""
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
:return: None
"""
key_and_size = f"{key}_{maps.shape[1]}"
# extract desired tokens
maps = maps[:, :, self.token_ids]
# merge attention heads to a single map per token
maps = torch.sum(maps, 0)
# store
if key_and_size not in self.collated_maps:
self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
self.collated_maps[key_and_size] += maps.cpu()
def write_maps_to_disk(self, path: str):
pil_image = self.get_stacked_maps_image()
if pil_image is not None:
pil_image.save(path, "PNG")
def get_stacked_maps_image(self) -> Optional[Image.Image]:
"""
Scale all collected attention maps to the same size, blend them together and return as an image.
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
"""
num_tokens = len(self.token_ids)
if num_tokens == 0:
return None
latents_height = self.latents_shape[0]
latents_width = self.latents_shape[1]
merged = None
for _key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
maps = torch.reshape(
torch.swapdims(maps, 0, 1),
[num_tokens, this_maps_height, this_maps_width],
)
# scale to output size if necessary
if this_scale_factor != 1:
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
# normalize
maps_min = torch.min(maps)
maps_range = torch.max(maps) - maps_min
# print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.1, 1.1) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
# merge together, producing a vertical stack
maps_stacked = torch.reshape(
maps_normalized_expanded_clamped,
[num_tokens * latents_height, latents_width],
)
if merged is None:
merged = maps_stacked
else:
# screen blend
merged = 1 - (1 - maps_stacked) * (1 - merged)
if merged is None:
return None
merged_bytes = merged.mul(0xFF).byte()
return Image.fromarray(merged_bytes.numpy(), mode="L")

View File

@@ -5,25 +5,26 @@ from contextlib import contextmanager
from typing import Any, Callable, Optional, Union
import torch
import torchvision
from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningData,
ExtraConditioningInfo,
IPAdapterConditioningInfo,
PostprocessingSettings,
SDXLConditioningInfo,
)
from .cross_attention_control import (
Context,
CrossAttentionType,
CrossAttnControlContext,
SwapCrossAttnContext,
get_cross_attention_modules,
setup_cross_attention_control_attention_processors,
)
from .cross_attention_map_saving import AttentionMapSaver
ModelForwardCallback: TypeAlias = Union[
# x, t, conditioning, Optional[cross-attention kwargs]
@@ -69,14 +70,12 @@ class InvokeAIDiffuserComponent:
self,
unet: UNet2DConditionModel,
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int,
):
old_attn_processors = unet.attn_processors
try:
self.cross_attention_control_context = Context(
self.cross_attention_control_context = CrossAttnControlContext(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
@@ -87,27 +86,6 @@ class InvokeAIDiffuserComponent:
finally:
self.cross_attention_control_context = None
unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
if dim is not None:
# sliced tokens attention map saving is not implemented
return
saver.add_attention_maps(slice, key)
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for identifier, module in tokens_cross_attention_modules:
key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
module.set_attention_slice_calculated_callback(
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
)
def remove_attention_map_saving(self):
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)
def do_controlnet_step(
self,
@@ -116,9 +94,12 @@ class InvokeAIDiffuserComponent:
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
conditioning_data,
conditioning_data: ConditioningData,
):
down_block_res_samples, mid_block_res_sample = None, None
# HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably run
# the controlnet separately for each conditioning input.
text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
@@ -149,28 +130,28 @@ class InvokeAIDiffuserComponent:
added_cond_kwargs = None
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
if type(text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": text_embeddings.add_time_ids,
}
encoder_hidden_states = conditioning_data.text_embeddings.embeds
encoder_hidden_states = text_embeddings.embeds
encoder_attention_mask = None
else:
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
if type(text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
text_embeddings.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
text_embeddings.add_time_ids,
],
dim=0,
),
@@ -180,7 +161,7 @@ class InvokeAIDiffuserComponent:
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds,
text_embeddings.embeds,
)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
@@ -224,54 +205,102 @@ class InvokeAIDiffuserComponent:
self,
sample: torch.Tensor,
timestep: torch.Tensor,
conditioning_data, # TODO: type
conditioning_data: ConditioningData,
step_index: int,
total_step_count: int,
**kwargs,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
percent_through
cross_attention_control_types_to_do = (
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
)
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
sample,
timestep,
conditioning_data,
cross_attention_control_types_to_do,
**kwargs,
)
elif self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
sample,
timestep,
conditioning_data,
**kwargs,
)
else:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
sample,
timestep,
conditioning_data,
**kwargs,
)
cond_next_xs = []
uncond_next_x = None
for text_conditioning in conditioning_data.text_embeddings:
if wants_cross_attention_control or self.sequential_guidance:
raise NotImplementedError(
"Sequential conditioning has not yet been updated to work with multiple text embeddings."
)
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
# control is currently only supported in sequential mode.
# (
# unconditioned_next_x,
# conditioned_next_x,
# ) = self._apply_standard_conditioning_sequentially(
# x=sample,
# sigma=timestep,
# conditioning_data=conditioning_data,
# cross_attention_control_types_to_do=cross_attention_control_types_to_do,
# down_block_additional_residuals=down_block_additional_residuals,
# mid_block_additional_residual=mid_block_additional_residual,
# down_intrablock_additional_residuals=down_intrablock_additional_residuals,
# )
else:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x=sample,
sigma=timestep,
cond_text_embedding=text_conditioning.text_conditioning_info,
uncond_text_embedding=conditioning_data.unconditioned_embeddings,
ip_adapter_conditioning=conditioning_data.ip_adapter_conditioning,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
)
cond_next_xs.append(conditioned_next_x)
# HACK(ryand): We re-run unconditioned denoising for each text embedding, but we should only need to do it
# once.
uncond_next_x = unconditioned_next_x
return unconditioned_next_x, conditioned_next_x
# TODO(ryand): Think about how to handle the batch dimension here. Should this be torch.stack()? It probably
# doesn't matter, as I'm sure there are many other places where we don't properly support batching.
cond_out = torch.concat(cond_next_xs, dim=0)
# Initialize count to 1e-9 to avoid division by zero.
cond_count = torch.ones_like(cond_out[0, ...]) * 1e-9
_, _, height, width = cond_out.shape
for te_idx, te in enumerate(conditioning_data.text_embeddings):
mask = te.mask
if mask is not None:
# Resize if necessary.
tf = torchvision.transforms.Resize(
(height, width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w)
mask = tf(mask)
# TODO(ryand): We are converting from uint8 to float here. Should we just be storing a float mask to
# begin with?
mask = mask.to(cond_out.device, cond_out.dtype)
# Make sure that all mask values are either 0.0 or 1.0.
# HACK(ryand): This is not the right place to be doing this. Just be clear about the expected format of
# the mask in the passed data structures.
mask[mask < 0.5] = 0.0
mask[mask >= 0.5] = 1.0
mask *= te.mask_strength
else:
# mask is None, so treat as a mask of all 1.0s (by taking advantage of torch's treatment of scalar
# values).
mask = 1.0
# Apply the mask and update the count.
cond_out[te_idx, ...] *= mask[0]
cond_count += mask[0]
# Combine the masked conditionings.
cond_out = cond_out.sum(dim=0, keepdim=True) / cond_count
return uncond_next_x, cond_out
def do_latent_postprocessing(
self,
@@ -335,7 +364,17 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
def _apply_standard_conditioning(
self,
x,
sigma,
cond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo],
uncond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo],
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage.
"""
@@ -343,39 +382,39 @@ class InvokeAIDiffuserComponent:
sigma_twice = torch.cat([sigma] * 2)
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_conditioning is not None:
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.stack(
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
for ipa_conditioning in ip_adapter_conditioning
]
}
added_cond_kwargs = None
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
if type(cond_text_embedding) is SDXLConditioningInfo:
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
uncond_text_embedding.pooled_embeds,
cond_text_embedding.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
uncond_text_embedding.add_time_ids,
cond_text_embedding.add_time_ids,
],
dim=0,
),
}
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
uncond_text_embedding.embeds, cond_text_embedding.embeds
)
both_results = self.model_forward_callback(
x_twice,
@@ -383,8 +422,10 @@ class InvokeAIDiffuserComponent:
both_conditionings,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
return unconditioned_next_x, conditioned_next_x
@@ -394,14 +435,21 @@ class InvokeAIDiffuserComponent:
x: torch.Tensor,
sigma,
conditioning_data: ConditioningData,
**kwargs,
):
cross_attention_control_types_to_do: list[CrossAttentionType],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
) -> tuple[torch.Tensor, torch.Tensor]:
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed.
"""
# low-memory sequential path
assert len(conditioning_data.text_embeddings) == 1
text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
# and T2I-Adapter residuals into two chunks.
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
@@ -410,7 +458,6 @@ class InvokeAIDiffuserComponent:
cond_down_block.append(_cond_down)
uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals:
@@ -419,12 +466,29 @@ class InvokeAIDiffuserComponent:
cond_down_intrablock.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
# Run unconditional UNet denoising.
# If cross-attention control is enabled, prepare the SwapCrossAttnContext.
cross_attn_processor_context = None
if self.cross_attention_control_context is not None:
# Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
# This list is empty because cross-attention control is not applied in the unconditioned pass. This field
# will be populated before the conditioned pass.
cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
index_map=self.cross_attention_control_context.cross_attention_index_map,
mask=self.cross_attention_control_context.cross_attention_mask,
cross_attention_types_to_do=[],
)
#####################
# Unconditioned pass
#####################
cross_attention_kwargs = None
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
@@ -434,14 +498,20 @@ class InvokeAIDiffuserComponent:
]
}
# Prepare cross-attention control kwargs for the unconditioned pass.
if cross_attn_processor_context is not None:
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
# Prepare SDXL conditioning kwargs for the unconditioned pass.
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
is_sdxl = type(text_embeddings) is SDXLConditioningInfo
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
}
# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
@@ -451,11 +521,15 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
# Run conditional UNet denoising.
###################
# Conditioned pass
###################
cross_attention_kwargs = None
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
@@ -465,105 +539,29 @@ class InvokeAIDiffuserComponent:
]
}
# Prepare cross-attention control kwargs for the conditioned pass.
if cross_attn_processor_context is not None:
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
# Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": text_embeddings.add_time_ids,
}
# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.text_embeddings.embeds,
text_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x
def _apply_cross_attention_controlled_conditioning(
self,
x: torch.Tensor,
sigma,
conditioning_data,
cross_attention_control_types_to_do,
**kwargs,
):
context: Context = self.cross_attention_control_context
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals:
_uncond_down, _cond_down = down_intrablock.chunk(2)
uncond_down_intrablock.append(_uncond_down)
cond_down_intrablock.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map,
mask=context.cross_attention_mask,
cross_attention_types_to_do=[],
)
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
}
# no cross attention for unconditioning (negative prompt)
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.unconditioned_embeddings.embeds,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
}
# do requested cross attention types for conditioning (positive prompt)
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
conditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.text_embeddings.embeds,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x
@@ -633,54 +631,3 @@ class InvokeAIDiffuserComponent:
self.last_percent_through = percent_through
return latents.to(device=dev)
# todo: make this work
@classmethod
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) # aka sigmas
deltas = None
uncond_latents = None
weighted_cond_list = (
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
)
# below is fugly omg
conditionings = [uc] + [c for c, weight in weighted_cond_list]
weights = [1] + [weight for c, weight in weighted_cond_list]
chunk_count = math.ceil(len(conditionings) / 2)
deltas = None
for chunk_index in range(chunk_count):
offset = chunk_index * 2
chunk_size = min(2, len(conditionings) - offset)
if chunk_size == 1:
c_in = conditionings[offset]
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
latents_b = None
else:
c_in = torch.cat(conditionings[offset : offset + 2])
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
if chunk_index == 0:
uncond_latents = latents_a
deltas = latents_b - uncond_latents
else:
deltas = torch.cat((deltas, latents_a - uncond_latents))
if latents_b is not None:
deltas = torch.cat((deltas, latents_b - uncond_latents))
# merge the weighted deltas together into a single merged delta
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
return uncond_latents + deltas_merged * global_guidance_scale