clean up attentions in stable diffusion (#2275)
@@ -175,7 +175,7 @@ class CrossAttention:
     q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
     q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
     attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
-    h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size))
+    h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
     return h_.sequential(self.to_out)
 
 class GEGLU:
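
The CrossAttention hunk only drops the shape= keyword from reshape; tinygrad's Tensor.reshape also takes the target shape positionally, so the two spellings are interchangeable. A minimal sketch of that equivalence (the tensor and shapes below are made up for illustration, assuming the keyword form the old code used is still accepted):

import numpy as np
from tinygrad.tensor import Tensor

t = Tensor.arange(24).reshape(2, 3, 4)
# keyword form (removed) and positional form (kept) produce the same result
assert np.array_equal(t.reshape(shape=(2, 12)).numpy(), t.reshape(2, 12).numpy())
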
@@ -348,29 +348,12 @@ class CLIPAttention:
     self.q_proj = Linear(self.embed_dim, self.embed_dim)
     self.out_proj = Linear(self.embed_dim, self.embed_dim)
 
-  def _shape(self, tensor, seq_len: int, bsz: int):
-    return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0,2,1,3)
-
   def __call__(self, hidden_states, causal_attention_mask):
     bsz, tgt_len, embed_dim = hidden_states.shape
-
-    query_states = self.q_proj(hidden_states)
-    key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-    value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-    query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape)
-    key_states = key_states.reshape(*proj_shape)
-    src_len = key_states.shape[1]
-    value_states = value_states.reshape(*proj_shape)
-
-    attn_output = Tensor.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_attention_mask)
-    attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
-    attn_output = attn_output.permute(0,2,1,3)
-    attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
-
-    attn_output = self.out_proj(attn_output)
-    return attn_output
+    q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
+    q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
+    attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
+    return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
 
 class CLIPEncoderLayer:
   def __init__(self):
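
In the CLIPAttention hunk, the removed body flattened batch and heads into one dimension (via _shape and proj_shape) before calling Tensor.scaled_dot_product_attention, then permuted and reshaped the result back; the new body keeps heads as their own dimension, matching the CrossAttention style above. Below is a minimal sketch, not part of the commit, checking that the two formulations agree; the sizes are arbitrary, the causal mask is omitted for brevity, and split_heads() is a stand-in for the removed _shape helper.

import numpy as np
from tinygrad.tensor import Tensor

bsz, tgt_len, num_heads, head_dim = 2, 5, 4, 8
embed_dim = num_heads * head_dim
q_in, k_in, v_in = [Tensor.randn(bsz, tgt_len, embed_dim) for _ in range(3)]

# old path: split heads, then fold them into the batch dimension
def split_heads(t):  # mirrors the removed _shape helper
  return t.reshape(bsz, tgt_len, num_heads, head_dim).permute(0, 2, 1, 3)
proj_shape = (bsz * num_heads, -1, head_dim)
q3, k3, v3 = [split_heads(t).reshape(*proj_shape) for t in (q_in, k_in, v_in)]
old = Tensor.scaled_dot_product_attention(q3, k3, v3)
old = old.reshape(bsz, num_heads, tgt_len, head_dim).permute(0, 2, 1, 3).reshape(bsz, tgt_len, embed_dim)

# new path: keep heads as a separate dimension, as in the committed code
q4, k4, v4 = [t.reshape(bsz, tgt_len, num_heads, head_dim).transpose(1, 2) for t in (q_in, k_in, v_in)]
new = Tensor.scaled_dot_product_attention(q4, k4, v4).transpose(1, 2).reshape(bsz, tgt_len, embed_dim)

# both paths should produce the same attention output
assert np.allclose(old.numpy(), new.numpy(), atol=1e-5)
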