log and exp are first class ops

This commit is contained in:
George Hotz
2020-12-28 10:00:30 -05:00
parent ffff98db78
commit 593233b668
6 changed files with 56 additions and 54 deletions

View File

@@ -55,8 +55,7 @@ class TransformerBlock:
value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
score = query.dot(key) * (1 / np.sqrt(self.head_size))
# TODO: this should be a normal softmax
weights = score.logsoftmax() # (bs, num_heads, T, T)
weights = score.softmax() # (bs, num_heads, T, T)
attention = weights.dot(value).transpose(order=(0,2,1,3))
x = inputs + attention.reshape(shape=(-1, self.num_heads * self.head_size)).dot(self.final)
# layernorm