From c8b569a8c74ffc1e3fd21cda2bca04749e7c500a Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 14 May 2022 21:28:39 -0700 Subject: [PATCH] cleaner comments --- models/transformer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/models/transformer.py b/models/transformer.py index d21be4e056..74a7e2674e 100644 --- a/models/transformer.py +++ b/models/transformer.py @@ -21,17 +21,18 @@ class TransformerBlock: self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim)) def attn(self, x): + # x: (bs, time, embed_dim) -> (bs, time, embed_dim) query, key, value = [x.linear(*y) \ .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \ for y in [self.query, self.key, self.value]] - query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size) - key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, T) - value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size) + query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, time, head_size) + key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, time) + value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, time, head_size) score = query.dot(key) * (1 / np.sqrt(self.head_size)) - weights = score.softmax() # (bs, num_heads, T, T) - attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size) + weights = score.softmax() # (bs, num_heads, time, time) + attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, time, num_heads, head_size) return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)