Files
InvokeAI/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py

73 lines
3.2 KiB
Python

import torch
class RegionalIPData:
    """A class to manage the data for regional IP-Adapter conditioning.

    It stores the per-adapter image prompt embeddings and attention scales, and pre-computes a pyramid of
    downsampled spatial masks so that each cross-attention layer can fetch the mask that matches its query
    sequence length via `get_masks()`.
    """

    def __init__(
        self,
        image_prompt_embeds: list[torch.Tensor],
        scales: list[float],
        masks: list[torch.Tensor],
        dtype: torch.dtype,
        device: torch.device,
        max_downscale_factor: int = 8,
    ):
        """Initialize a `RegionalIPData` object.

        Args:
            image_prompt_embeds: One entry per IP-Adapter (see the attribute comment below for the expected shape).
            scales: One attention scale per IP-Adapter.
            masks: One spatial mask tensor per IP-Adapter. The tensors are concatenated along dim 1, and the result
                must be 4-D with a batch size of 1, i.e. each mask is laid out as (1, n, h, w).
            dtype: The dtype that the prepared masks are cast to.
            device: The device that the prepared masks are moved to.
            max_downscale_factor: The largest spatial downscale factor to prepare masks for. Masks are prepared for
                the factors 1, 2, 4, ... up to and including this value.

        Raises:
            ValueError: If the three lists differ in length, or if the masks are not 4-D with a batch size of 1.
        """
        # Explicit raise rather than `assert`, so the check is not stripped when Python runs with -O.
        if not (len(image_prompt_embeds) == len(scales) == len(masks)):
            raise ValueError(
                "image_prompt_embeds, scales, and masks must all have the same length, but got lengths "
                f"{len(image_prompt_embeds)}, {len(scales)}, and {len(masks)}."
            )

        # The image prompt embeddings.
        # image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
        # has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
        self.image_prompt_embeds = image_prompt_embeds

        # The scales for the IP-Adapter attention.
        # scales[i] contains the attention scale for the i'th IP-Adapter.
        self.scales = scales

        # The IP-Adapter masks.
        # self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
        # s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
        # regions and 0.0 for excluded regions.
        self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)

    def _prepare_masks(
        self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
    ) -> dict[int, torch.Tensor]:
        """Prepare the masks for the IP-Adapter attention.

        Returns:
            A dict that maps a query sequence length (h * w at one downsampling level) to the flattened masks for
            that level, with shape (1, num_masks, h * w, 1).

        Raises:
            ValueError: If the concatenated masks are not 4-D, or if their batch size is not 1.
        """
        # Concatenate the masks so that they can be processed more efficiently.
        mask_tensor = torch.cat(masks, dim=1).to(device=device, dtype=dtype)

        if mask_tensor.ndim != 4:
            raise ValueError(
                f"Expected the masks to have 4 dimensions (batch, n, h, w), but got shape {tuple(mask_tensor.shape)}."
            )

        # Pooling only changes the spatial dimensions, so the leading dimensions can be read (and checked) once.
        batch_size, num_ip_adapters = mask_tensor.shape[:2]
        # Only a batch size of 1 is supported, because batch handling has not been thought through for this feature
        # yet.
        if batch_size != 1:
            raise ValueError(f"Regional IP-Adapter masks only support a batch size of 1, but got {batch_size}.")

        masks_by_seq_len: dict[int, torch.Tensor] = {}

        # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
        downscale_factor = 1
        while downscale_factor <= max_downscale_factor:
            if downscale_factor > 1:
                # We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
                # regions to be lost entirely.
                #
                # ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
                #
                # TODO(ryand): In the future, we may want to experiment with other downsampling methods.
                mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2, ceil_mode=True)

            # The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w
            # of the spatial features.
            h, w = mask_tensor.shape[-2:]
            masks_by_seq_len[h * w] = mask_tensor.view((batch_size, num_ip_adapters, -1, 1))

            downscale_factor *= 2

        return masks_by_seq_len

    def get_masks(self, query_seq_len: int) -> torch.Tensor:
        """Get the mask for the given query sequence length.

        Raises:
            KeyError: If no mask was prepared for `query_seq_len`.
        """
        try:
            return self._masks_by_seq_len[query_seq_len]
        except KeyError:
            # Same exception type as a plain dict lookup, but with enough context to diagnose a size mismatch.
            raise KeyError(
                f"No IP-Adapter mask was prepared for query_seq_len={query_seq_len}. "
                f"Available sequence lengths: {sorted(self._masks_by_seq_len)}."
            ) from None