diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py
index 8561534c0b..f5d8f404be 100644
--- a/examples/stable_diffusion.py
+++ b/examples/stable_diffusion.py
@@ -175,7 +175,7 @@ class CrossAttention:
     q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
     q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
     attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
-    h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size))
+    h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
     return h_.sequential(self.to_out)
 
 class GEGLU:
@@ -348,29 +348,12 @@ class CLIPAttention:
     self.q_proj = Linear(self.embed_dim, self.embed_dim)
     self.out_proj = Linear(self.embed_dim, self.embed_dim)
 
-  def _shape(self, tensor, seq_len: int, bsz: int):
-    return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0,2,1,3)
-
   def __call__(self, hidden_states, causal_attention_mask):
     bsz, tgt_len, embed_dim = hidden_states.shape
-
-    query_states = self.q_proj(hidden_states)
-    key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-    value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-    query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape)
-    key_states = key_states.reshape(*proj_shape)
-    src_len = key_states.shape[1]
-    value_states = value_states.reshape(*proj_shape)
-
-    attn_output = Tensor.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_attention_mask)
-    attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
-    attn_output = attn_output.permute(0,2,1,3)
-    attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
-
-    attn_output = self.out_proj(attn_output)
-    return attn_output
+    q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
+    q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
+    attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
+    return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
 
 class CLIPEncoderLayer:
   def __init__(self):
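
The CLIPAttention change drops the `_shape` helper and the `bsz * num_heads` flattening, keeping the heads in a `(bsz, num_heads, seq, head_dim)` layout and letting `scaled_dot_product_attention` broadcast the causal mask, mirroring the style CrossAttention already uses. Below is a minimal equivalence sketch, not part of the diff: it assumes tinygrad's `Tensor` and `nn.Linear` APIs, and `OldCLIPAttention` / `NewCLIPAttention` are hypothetical wrappers around the removed and added code paths with shared weights.

```python
# Equivalence sketch for the CLIPAttention rewrite above (assumptions: tinygrad
# Tensor / nn.Linear; OldCLIPAttention / NewCLIPAttention are hypothetical names).
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import Linear

class OldCLIPAttention:
  def __init__(self):
    self.embed_dim, self.num_heads = 768, 12
    self.head_dim = self.embed_dim // self.num_heads
    self.k_proj, self.v_proj = Linear(self.embed_dim, self.embed_dim), Linear(self.embed_dim, self.embed_dim)
    self.q_proj, self.out_proj = Linear(self.embed_dim, self.embed_dim), Linear(self.embed_dim, self.embed_dim)

  def _shape(self, tensor, seq_len, bsz):
    return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0,2,1,3)

  def __call__(self, hidden_states, causal_attention_mask):
    # removed code path: flatten (bsz, num_heads) before attention, unflatten after
    bsz, tgt_len, embed_dim = hidden_states.shape
    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
    query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz).reshape(*proj_shape)
    key_states = self._shape(self.k_proj(hidden_states), -1, bsz).reshape(*proj_shape)
    value_states = self._shape(self.v_proj(hidden_states), -1, bsz).reshape(*proj_shape)
    attn_output = Tensor.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_attention_mask)
    attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim).permute(0,2,1,3).reshape(bsz, tgt_len, embed_dim)
    return self.out_proj(attn_output)

class NewCLIPAttention(OldCLIPAttention):
  def __call__(self, hidden_states, causal_attention_mask):
    # added code path: keep heads in a (bsz, num_heads, seq, head_dim) layout
    bsz, tgt_len, embed_dim = hidden_states.shape
    q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
    q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
    attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
    return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))

old, new = OldCLIPAttention(), NewCLIPAttention()
# share the randomly initialized projection weights so both paths are comparable
new.q_proj, new.k_proj, new.v_proj, new.out_proj = old.q_proj, old.k_proj, old.v_proj, old.out_proj

x = Tensor.randn(2, 77, 768)
# causal mask built with numpy, shaped like the (1, 1, 77, 77) mask the text encoder passes in
mask = Tensor(np.triu(np.full((1, 1, 77, 77), -np.inf, dtype=np.float32), k=1))
np.testing.assert_allclose(old(x, mask).numpy(), new(x, mask).numpy(), atol=1e-4)
print("old and new CLIPAttention outputs match")
```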