diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index d8348d66bb..692ff44426 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -14,6 +14,8 @@ from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from extra.utils import download_file from tinygrad.state import torch_load, load_state_dict +# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code + class AttnBlock: def __init__(self, in_channels): self.norm = GroupNorm(32, in_channels) @@ -29,8 +31,19 @@ class AttnBlock: # compute attention b,c,h,w = q.shape - q,k,v = [x.reshape(b,c,h*w).transpose(1,2) for x in (q,k,v)] - h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w) + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = q @ k + w_ = w_ * (c**(-0.5)) + w_ = w_.softmax() + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) + h_ = v @ w_ + h_ = h_.reshape(b,c,h,w) + return x + self.proj_out(h_) class ResnetBlock: @@ -165,6 +178,7 @@ class CrossAttention: self.to_q = Linear(query_dim, n_heads*d_head, bias=False) self.to_k = Linear(context_dim, n_heads*d_head, bias=False) self.to_v = Linear(context_dim, n_heads*d_head, bias=False) + self.scale = d_head ** -0.5 self.num_heads = n_heads self.head_size = d_head self.to_out = [Linear(n_heads*d_head, query_dim)] @@ -172,8 +186,14 @@ class CrossAttention: def __call__(self, x, context=None): context = x if context is None else context q,k,v = self.to_q(x), self.to_k(context), self.to_v(context) - q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)] - attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2) + q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size) + k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,3,1) # (bs, num_heads, head_size, time) + v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size) + + score = q.dot(k) * self.scale + weights = score.softmax() # (bs, num_heads, time, time) + attention = weights.dot(v).permute(0,2,1,3) # (bs, time, num_heads, head_size) + h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)) return h_.sequential(self.to_out) @@ -342,6 +362,7 @@ class CLIPAttention: self.embed_dim = 768 self.num_heads = 12 self.head_dim = self.embed_dim // self.num_heads + self.scale = self.head_dim**-0.5 self.k_proj = Linear(self.embed_dim, self.embed_dim) self.v_proj = Linear(self.embed_dim, self.embed_dim) self.q_proj = Linear(self.embed_dim, self.embed_dim) @@ -353,7 +374,7 @@ class CLIPAttention: def __call__(self, hidden_states, causal_attention_mask): bsz, tgt_len, embed_dim = hidden_states.shape - query_states = self.q_proj(hidden_states) + query_states = self.q_proj(hidden_states) * self.scale key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) @@ -363,8 +384,15 @@ class CLIPAttention: src_len = key_states.shape[1] value_states = value_states.reshape(*proj_shape) - causal_attention_mask = causal_attention_mask.reshape(bsz * self.num_heads, tgt_len, src_len) - attn_output = Tensor.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_attention_mask) + attn_weights = query_states @ key_states.permute(0,2,1) + + attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask + attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = attn_weights.softmax() + + attn_output = attn_weights @ value_states + attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.permute(0,2,1,3) attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)