import einops
import torch

from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.math import attention
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock


class CustomDoubleStreamBlockProcessor:
    """A class containing a custom implementation of DoubleStreamBlock.forward() with additional features
    (IP-Adapter, etc.).
    """

    @staticmethod
    def _double_stream_block_forward(
        block: DoubleStreamBlock,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
        values.
        """
        img_mod1, img_mod2 = block.img_mod(vec)
        txt_mod1, txt_mod2 = block.txt_mod(vec)

        # prepare image for attention
        img_modulated = block.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = block.img_attn.qkv(img_modulated)
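        # The fused QKV projection is unpacked from (B, L, 3 * H * D) into three
        # (B, H, L, D) tensors: queries, keys, and values.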
        img_q, img_k, img_v = einops.rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
        img_q, img_k = block.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = block.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = block.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = einops.rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
        txt_q, txt_k = block.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)
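        # Txt and img streams are concatenated along the sequence dim (txt tokens
        # first), so a single joint attention pass covers both modalities. The
        # rotary position embeddings in `pe` are applied to q and k inside
        # `attention()`.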

        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
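
        # Each stream receives two gated residual updates: one from the attention
        # output and one from its MLP branch.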
        # calculate the img blocks
        img = img + img_mod1.gate * block.img_attn.proj(img_attn)
        img = img + img_mod2.gate * block.img_mlp((1 + img_mod2.scale) * block.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * block.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * block.txt_mlp((1 + txt_mod2.scale) * block.txt_norm2(txt) + txt_mod2.shift)
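        # img_q is returned alongside the block outputs so that callers (e.g. the
        # IP-Adapter path below) can reuse the image-stream queries.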
        return img, txt, img_q

    @staticmethod
    def custom_double_block_forward(
        timestep_index: int,
        total_num_timesteps: int,
        block_index: int,
        block: DoubleStreamBlock,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        ip_adapter_extensions: list[XLabsIPAdapterExtension],
        regional_prompting_extension: RegionalPromptingExtension,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """A custom implementation of DoubleStreamBlock.forward() with additional features:
        - IP-Adapter support
        - Regional prompting attention masks
        """
        attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
        img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
            block, img, txt, vec, pe, attn_mask=attn_mask
        )
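
        # The image-stream queries (img_q) computed above are handed to each
        # IP-Adapter extension; in the XLabs design these are presumably used to
        # attend over the adapter's image-prompt keys/values.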
        # Apply IP-Adapter conditioning.
        for ip_adapter_extension in ip_adapter_extensions:
            img = ip_adapter_extension.run_ip_adapter(
                timestep_index=timestep_index,
                total_num_timesteps=total_num_timesteps,
                block_index=block_index,
                block=block,
                img_q=img_q,
                img=img,
            )

        return img, txt


class CustomSingleStreamBlockProcessor:
    """A class containing a custom implementation of SingleStreamBlock.forward() with additional features
    (masking, etc.).
    """

    @staticmethod
    def _single_stream_block_forward(
        block: SingleStreamBlock,
        x: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """This function is a direct copy of SingleStreamBlock.forward()."""
        mod, _ = block.modulation(vec)
        x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift
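        # linear1 is a fused projection: one matmul yields both the attention QKV
        # (3 * hidden_size) and the MLP hidden activations (mlp_hidden_dim).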
        qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1)

        q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
        q, k = block.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2))
        return x + mod.gate * output

    @staticmethod
    def custom_single_block_forward(
        timestep_index: int,
        total_num_timesteps: int,
        block_index: int,
        block: SingleStreamBlock,
        img: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        regional_prompting_extension: RegionalPromptingExtension,
    ) -> torch.Tensor:
        """A custom implementation of SingleStreamBlock.forward() with additional features:
        - Masking (regional prompting attention masks)
        """
        attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
        return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)
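

# A hypothetical call site, for orientation only. The names `model`, `t_idx`,
# `num_steps`, and the surrounding denoising loop are illustrative and are not
# part of this module or its public API:
#
#     for block_index, block in enumerate(model.double_blocks):
#         img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
#             timestep_index=t_idx,
#             total_num_timesteps=num_steps,
#             block_index=block_index,
#             block=block,
#             img=img,
#             txt=txt,
#             vec=vec,
#             pe=pe,
#             ip_adapter_extensions=ip_adapter_extensions,
#             regional_prompting_extension=regional_prompting_extension,
#         )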