Transformer: use Tensor.scaled_dot_product_attention (#1520)

This commit is contained in:
Jacky Lee
2023-08-11 09:00:37 -07:00
committed by GitHub
parent 38fe84d92b
commit 2e85fce068

View File

@@ -24,18 +24,8 @@ class TransformerBlock:
def attn(self, x):
# x: (bs, time, embed_dim) -> (bs, time, embed_dim)
query, key, value = [x.linear(*y) \
.reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
for y in [self.query, self.key, self.value]]
query = query.permute(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
key = key.permute(order=(0,2,3,1)) # (bs, num_heads, head_size, time)
value = value.permute(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
score = query.dot(key) * (1 / np.sqrt(self.head_size))
weights = score.softmax() # (bs, num_heads, time, time)
attention = weights.dot(value).permute(order=(0,2,1,3)) # (bs, time, num_heads, head_size)
query, key, value = [x.linear(*y).reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)).transpose(1,2) for y in [self.query, self.key, self.value]]
attention = Tensor.scaled_dot_product_attention(query, key, value).transpose(1,2)
return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)
def __call__(self, x):