mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 17:18:11 -05:00
Compare commits
1 Commits
main
...
ryan/flux-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ffa89126d1 |
@@ -9,8 +9,12 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
|
||||
# Replaced original einops.rearrange(...) call with torch.reshape(...) for slightly faster performance.
|
||||
# Original call: x = rearrange(x, "B H L D -> B L (H D)")
|
||||
# x = x.permute(0, 2, 1, 3) # BHLD -> BLHD
|
||||
# x = x.reshape(x.shape[0], x.shape[1], -1) # BLHD -> BL(HD)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
return x
|
||||
|
||||
|
||||
@@ -23,6 +27,9 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
# Replaced original einops.rearrange(...) call with torch.view(...) for slightly faster performance.
|
||||
# Original call: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
# out = out.view(*out.shape[:-1], 2, 2)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.float()
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.math import attention, rope
|
||||
@@ -94,13 +93,14 @@ class SelfAttention(nn.Module):
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
x = attention(q, k, v, pe=pe)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
# Unused code for reference:
|
||||
# def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
# qkv = self.qkv(x)
|
||||
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
# q, k = self.norm(q, k, v)
|
||||
# x = attention(q, k, v, pe=pe)
|
||||
# x = self.proj(x)
|
||||
# return x
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -163,14 +163,22 @@ class DoubleStreamBlock(nn.Module):
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
# img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(
|
||||
2, 0, 3, 1, 4
|
||||
)
|
||||
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
# txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(
|
||||
2, 0, 3, 1, 4
|
||||
)
|
||||
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
@@ -229,7 +237,8 @@ class SingleStreamBlock(nn.Module):
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
|
||||
Reference in New Issue
Block a user