Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 15:38:29 -05:00)
cleaner comments
@@ -21,17 +21,18 @@ class TransformerBlock:
     self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
 
   def attn(self, x):
     # x: (bs, time, embed_dim) -> (bs, time, embed_dim)
     query, key, value = [x.linear(*y) \
       .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
       for y in [self.query, self.key, self.value]]
 
-    query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
-    key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, T)
-    value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
+    query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
+    key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, time)
+    value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
 
     score = query.dot(key) * (1 / np.sqrt(self.head_size))
-    weights = score.softmax() # (bs, num_heads, T, T)
-    attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
+    weights = score.softmax() # (bs, num_heads, time, time)
+    attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, time, num_heads, head_size)
 
     return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)
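Note (not part of the commit): as a sanity check on the shape bookkeeping the renamed comments describe, here is a minimal numpy sketch of the same multi-head attention flow. The sizes (bs, time, num_heads, head_size) are illustrative, and project() is a hypothetical stand-in for tinygrad's learned x.linear(*weights) projections; it only preserves the reshape/transpose bookkeeping being verified, not the actual linear maps.

import numpy as np

# illustrative sizes (assumptions, not from the commit)
bs, time, num_heads, head_size = 2, 5, 4, 8
embed_dim = num_heads * head_size

x = np.random.randn(bs, time, embed_dim)

# hypothetical stand-in for x.linear(*y).reshape(...): splits embed_dim into heads
def project(x):
  return x.reshape(bs, time, num_heads, head_size)

query = project(x).transpose(0, 2, 1, 3)  # (bs, num_heads, time, head_size)
key   = project(x).transpose(0, 2, 3, 1)  # (bs, num_heads, head_size, time)
value = project(x).transpose(0, 2, 1, 3)  # (bs, num_heads, time, head_size)

score = (query @ key) * (1 / np.sqrt(head_size))   # (bs, num_heads, time, time)
score = score - score.max(axis=-1, keepdims=True)  # stabilize before exp
weights = np.exp(score) / np.exp(score).sum(axis=-1, keepdims=True)  # softmax over keys
attention = (weights @ value).transpose(0, 2, 1, 3)  # (bs, time, num_heads, head_size)
out = attention.reshape(bs, time, num_heads * head_size)

assert out.shape == (bs, time, embed_dim)  # back to (bs, time, embed_dim), as the comments say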