import torch


class RegionalIPData:
    """A class to manage the data for regional IP-Adapter conditioning."""

    def __init__(
        self,
        image_prompt_embeds: list[torch.Tensor],
        scales: list[float],
        masks: list[torch.Tensor],
        dtype: torch.dtype,
        device: torch.device,
        max_downscale_factor: int = 8,
    ):
        """Initialize a `RegionalIPData` object."""
        assert len(image_prompt_embeds) == len(scales) == len(masks)

        # The image prompt embeddings.
        # image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
        # has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
        self.image_prompt_embeds = image_prompt_embeds

        # The scales for the IP-Adapter attention.
        # scales[i] contains the attention scale for the i'th IP-Adapter.
        self.scales = scales

        # The IP-Adapter masks.
        # self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
        # s. It has shape (batch_size, num_ip_adapters, query_seq_len, 1). The masks have values of 1.0 for included
        # regions and 0.0 for excluded regions.
        self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)

    def _prepare_masks(
        self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
    ) -> dict[int, torch.Tensor]:
        """Prepare the masks for the IP-Adapter attention."""
        # Concatenate the masks so that they can be processed more efficiently.
        mask_tensor = torch.cat(masks, dim=1)

        mask_tensor = mask_tensor.to(device=device, dtype=dtype)

        masks_by_seq_len: dict[int, torch.Tensor] = {}

        # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
        downscale_factor = 1
        while downscale_factor <= max_downscale_factor:
            b, num_ip_adapters, h, w = mask_tensor.shape
            # Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
            assert b == 1

            # The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w
            # of the spatial features.
            query_seq_len = h * w

            masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))

            downscale_factor *= 2
            if downscale_factor <= max_downscale_factor:
                # We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
                # regions to be lost entirely.
                #
                # ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
                #
                # TODO(ryand): In the future, we may want to experiment with other downsampling methods.
                mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2, ceil_mode=True)

        return masks_by_seq_len

    def get_masks(self, query_seq_len: int) -> torch.Tensor:
        """Get the mask for the given query sequence length."""
        return self._masks_by_seq_len[query_seq_len]
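

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The shapes and
# values below are assumptions chosen only to exercise the class; in practice
# the embeddings come from an IP-Adapter image encoder and the masks from
# user-supplied regional masks resized to the latent resolution.
if __name__ == "__main__":
    latent_h = latent_w = 64  # assumed latent spatial size

    # One IP-Adapter: embeddings of shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
    image_prompt_embeds = [torch.randn(1, 1, 4, 768)]
    # One mask per IP-Adapter, shape (batch_size, 1, h, w); 1.0 = included region, 0.0 = excluded.
    masks = [torch.ones(1, 1, latent_h, latent_w)]

    regional_ip_data = RegionalIPData(
        image_prompt_embeds=image_prompt_embeds,
        scales=[1.0],
        masks=masks,
        dtype=torch.float32,
        device=torch.device("cpu"),
    )

    # With max_downscale_factor=8, masks are stored for query sequence lengths
    # 64*64, 32*32, 16*16, and 8*8.
    print(regional_ip_data.get_masks(query_seq_len=32 * 32).shape)  # torch.Size([1, 1, 1024, 1])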