Compare commits

...

1 Commit

2 changed files with 28 additions and 12 deletions

View File

@@ -9,8 +9,12 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
     q, k = apply_rope(q, k, pe)

     x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
-    x = rearrange(x, "B H L D -> B L (H D)")
+    # Replaced original einops.rearrange(...) call with torch.reshape(...) for slightly faster performance.
+    # Original call: x = rearrange(x, "B H L D -> B L (H D)")
+    # x = x.permute(0, 2, 1, 3)  # BHLD -> BLHD
+    # x = x.reshape(x.shape[0], x.shape[1], -1)  # BLHD -> BL(HD)
+    x = rearrange(x, "B H L D -> B L (H D)")

     return x
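As a sanity check on the alternative documented in those comments, the permute + reshape pair should produce exactly the same tensor as the rearrange call that remains active. A minimal sketch, assuming torch and einops are installed; the shapes are arbitrary placeholders, not the model's real dimensions:

import torch
from einops import rearrange

x = torch.randn(2, 24, 16, 128)  # placeholder B, H, L, D
ref = rearrange(x, "B H L D -> B L (H D)")
alt = x.permute(0, 2, 1, 3)                        # B H L D -> B L H D
alt = alt.reshape(alt.shape[0], alt.shape[1], -1)  # B L H D -> B L (H D)
assert torch.equal(ref, alt)

Note that reshape (rather than view) is needed in the second step because the tensor is no longer contiguous after the permute.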
@@ -23,6 +27,9 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     omega = 1.0 / (theta**scale)
     out = torch.einsum("...n,d->...nd", pos, omega)
     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+    # Replaced original einops.rearrange(...) call with torch.view(...) for slightly faster performance.
+    # Original call: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+    # out = out.view(*out.shape[:-1], 2, 2)
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
     return out.float()
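Likewise, the commented-out view splits the trailing axis of size 4 into a 2x2 block in row-major order, which is exactly what the "(i j)" pattern with i=2, j=2 describes. A minimal check with placeholder shapes:

import torch
from einops import rearrange

out = torch.randn(2, 64, 32, 4)  # placeholder b, n, d, (i j) with i = j = 2
ref = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
alt = out.view(*out.shape[:-1], 2, 2)
assert torch.equal(ref, alt)

view is safe here because the stacked tensor is contiguous at that point; out.unflatten(-1, (2, 2)) would be an equivalent, more explicit spelling.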

View File

@@ -4,7 +4,6 @@ import math
 from dataclasses import dataclass

 import torch
-from einops import rearrange
 from torch import Tensor, nn

 from invokeai.backend.flux.math import attention, rope
@@ -94,13 +93,14 @@ class SelfAttention(nn.Module):
         self.norm = QKNorm(head_dim)
         self.proj = nn.Linear(dim, dim)

-    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
-        qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
-        q, k = self.norm(q, k, v)
-        x = attention(q, k, v, pe=pe)
-        x = self.proj(x)
-        return x
+    # Unused code for reference:
+    # def forward(self, x: Tensor, pe: Tensor) -> Tensor:
+    # qkv = self.qkv(x)
+    # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+    # q, k = self.norm(q, k, v)
+    # x = attention(q, k, v, pe=pe)
+    # x = self.proj(x)
+    # return x


 @dataclass
@@ -163,14 +163,22 @@ class DoubleStreamBlock(nn.Module):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(
+            2, 0, 3, 1, 4
+        )
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(
+            2, 0, 3, 1, 4
+        )
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         # run actual attention
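The view + permute introduced in this hunk reproduces the "B L (K H D) -> K B H L D" rearrange: view splits the fused last axis into (K, H, D) in row-major order, and permute moves K to the front so the result can be unpacked into q, k, v. A minimal check with placeholder sizes (B, L, H, D below are illustrative, not the model's actual dimensions):

import torch
from einops import rearrange

B, L, K, H, D = 2, 16, 3, 24, 128  # placeholder sizes
qkv = torch.randn(B, L, K * H * D)
ref = rearrange(qkv, "B L (K H D) -> K B H L D", K=K, H=H)
alt = qkv.view(qkv.shape[0], qkv.shape[1], K, H, -1).permute(2, 0, 3, 1, 4)
assert torch.equal(ref, alt)
q, k, v = alt  # unpack along the leading K axis

Both forms return views over the same storage; no data is copied until a downstream op requires contiguity.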
@@ -229,7 +237,8 @@ class SingleStreamBlock(nn.Module):
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)

         # compute attention
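Since the motivation given in the comments is "slightly faster performance", a rough way to check that claim is to time both forms on the same input. A sketch with placeholder shapes; both calls only build views, so this mostly measures einops' pattern-handling overhead, and the numbers will vary with hardware, device, and tensor sizes:

import timeit

import torch
from einops import rearrange

num_heads = 24
qkv = torch.randn(1, 4096, 3 * num_heads * 128)  # placeholder B, L, (K H D)

def with_rearrange():
    return rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=num_heads)

def with_view_permute():
    return qkv.view(qkv.shape[0], qkv.shape[1], 3, num_heads, -1).permute(2, 0, 3, 1, 4)

print("rearrange    :", timeit.timeit(with_rearrange, number=10_000))
print("view/permute :", timeit.timeit(with_view_permute, number=10_000))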