diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py
index 692ff44426..d8348d66bb 100644
--- a/examples/stable_diffusion.py
+++ b/examples/stable_diffusion.py
@@ -14,8 +14,6 @@ from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
 from extra.utils import download_file
 from tinygrad.state import torch_load, load_state_dict
 
-# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
-
 class AttnBlock:
   def __init__(self, in_channels):
     self.norm = GroupNorm(32, in_channels)
@@ -31,19 +29,8 @@ class AttnBlock:
 
     # compute attention
     b,c,h,w = q.shape
-    q = q.reshape(b,c,h*w)
-    q = q.permute(0,2,1)   # b,hw,c
-    k = k.reshape(b,c,h*w) # b,c,hw
-    w_ = q @ k
-    w_ = w_ * (c**(-0.5))
-    w_ = w_.softmax()
-
-    # attend to values
-    v = v.reshape(b,c,h*w)
-    w_ = w_.permute(0,2,1)
-    h_ = v @ w_
-    h_ = h_.reshape(b,c,h,w)
-
+    q,k,v = [x.reshape(b,c,h*w).transpose(1,2) for x in (q,k,v)]
+    h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w)
     return x + self.proj_out(h_)
 
 class ResnetBlock:
@@ -178,7 +165,6 @@ class CrossAttention:
     self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
     self.to_k = Linear(context_dim, n_heads*d_head, bias=False)
     self.to_v = Linear(context_dim, n_heads*d_head, bias=False)
-    self.scale = d_head ** -0.5
     self.num_heads = n_heads
     self.head_size = d_head
     self.to_out = [Linear(n_heads*d_head, query_dim)]
@@ -186,14 +172,8 @@ class CrossAttention:
   def __call__(self, x, context=None):
     context = x if context is None else context
     q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
-    q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3)  # (bs, num_heads, time, head_size)
-    k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,3,1)  # (bs, num_heads, head_size, time)
-    v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3)  # (bs, num_heads, time, head_size)
-
-    score = q.dot(k) * self.scale
-    weights = score.softmax()                     # (bs, num_heads, time, time)
-    attention = weights.dot(v).permute(0,2,1,3)   # (bs, time, num_heads, head_size)
-
+    q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
+    attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
     h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size))
     return h_.sequential(self.to_out)
 
@@ -362,7 +342,6 @@ class CLIPAttention:
     self.embed_dim = 768
     self.num_heads = 12
     self.head_dim = self.embed_dim // self.num_heads
-    self.scale = self.head_dim**-0.5
     self.k_proj = Linear(self.embed_dim, self.embed_dim)
     self.v_proj = Linear(self.embed_dim, self.embed_dim)
     self.q_proj = Linear(self.embed_dim, self.embed_dim)
@@ -374,7 +353,7 @@ class CLIPAttention:
   def __call__(self, hidden_states, causal_attention_mask):
     bsz, tgt_len, embed_dim = hidden_states.shape
 
-    query_states = self.q_proj(hidden_states) * self.scale
+    query_states = self.q_proj(hidden_states)
     key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
     value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
 
@@ -384,15 +363,8 @@ class CLIPAttention:
     src_len = key_states.shape[1]
     value_states = value_states.reshape(*proj_shape)
 
-    attn_weights = query_states @ key_states.permute(0,2,1)
-
-    attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
-    attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
-
-    attn_weights = attn_weights.softmax()
-
-    attn_output = attn_weights @ value_states
-
+    causal_attention_mask = causal_attention_mask.reshape(bsz * self.num_heads, tgt_len, src_len)
+    attn_output = Tensor.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_attention_mask)
     attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
     attn_output = attn_output.permute(0,2,1,3)
     attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
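Note on the dropped `self.scale` factors: they are handled inside the fused call. The diff treats `Tensor.scaled_dot_product_attention(q, k, v, attn_mask=...)` as computing softmax(q @ k^T / sqrt(head_dim) + mask) @ v over the last two axes, i.e. the PyTorch-style semantics, which is why the explicit `d_head ** -0.5` multiplications can be removed. A rough NumPy reference of that formula, for comparison only (the name `sdpa_reference` and its arguments are illustrative and not part of the patch):

import numpy as np

def sdpa_reference(q, k, v, attn_mask=None):
  # scores = q @ k^T, scaled by 1/sqrt(head_dim), with an optional additive mask
  scores = q @ k.swapaxes(-2, -1) / np.sqrt(q.shape[-1])
  if attn_mask is not None: scores = scores + attn_mask
  # softmax over the key axis (shifted by the row max for numerical stability)
  scores = scores - scores.max(axis=-1, keepdims=True)
  weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
  # weighted sum of the values
  return weights @ v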