clean up attentions in stable diffusion (#2275)

This commit is contained in:
chenyu
2023-11-11 14:25:36 -05:00
committed by GitHub
parent 453f48ce02
commit 5ef8d682e3

View File

@@ -175,7 +175,7 @@ class CrossAttention:
q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size))
h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
return h_.sequential(self.to_out)
class GEGLU:
@@ -348,29 +348,12 @@ class CLIPAttention:
self.q_proj = Linear(self.embed_dim, self.embed_dim)
self.out_proj = Linear(self.embed_dim, self.embed_dim)
def _shape(self, tensor, seq_len: int, bsz: int):
return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0,2,1,3)
def __call__(self, hidden_states, causal_attention_mask):
bsz, tgt_len, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape)
key_states = key_states.reshape(*proj_shape)
src_len = key_states.shape[1]
value_states = value_states.reshape(*proj_shape)
attn_output = Tensor.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_attention_mask)
attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.permute(0,2,1,3)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
class CLIPEncoderLayer:
def __init__(self):