diff --git a/models/transformer.py b/models/transformer.py
index 9a114920be..3f75feb4fe 100644
--- a/models/transformer.py
+++ b/models/transformer.py
@@ -27,6 +27,9 @@ class TransformerBlock:
     self.ff1 = (Tensor.uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
     self.ff2 = (Tensor.uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
 
+    self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
+    self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
+
   def __call__(self, x):
     # bs x T x embed_dim
     bs = x.shape[0]
@@ -47,9 +50,9 @@ class TransformerBlock:
     attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
 
     x = inputs + attention.reshape(shape=(-1, embed_dim)).affine(self.final).dropout(0.1)
-    x = layernorm(x, embed_dim)
+    x = layernorm(x, embed_dim).affine(self.ln1)
     x = x + x.affine(self.ff1).relu().affine(self.ff2).dropout(0.1)
-    x = layernorm(x, embed_dim)
+    x = layernorm(x, embed_dim).affine(self.ln2)
     return x.reshape(shape=(bs, -1, embed_dim))
 
 class Transformer:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7810621cb3..4557bbb885 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -276,7 +276,10 @@ class Tensor:
     return self._pool2d(*kernel_size).max(axis=(3,5))
 
   def affine(self, params):
-    return self.dot(params[0]).add(params[1].reshape(shape=[1, -1]))
+    if len(params[0].shape) == 1: # elementwise affine
+      return self.mul(params[0].reshape(shape=[1, -1])).add(params[1].reshape(shape=[1, -1]))
+    else:
+      return self.dot(params[0]).add(params[1].reshape(shape=[1, -1]))
 
 # An instantiation of the Function is the Context
 class Function:
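The patch gives each `TransformerBlock` learnable LayerNorm parameters (`ln1`/`ln2`, initialized to ones and zeros) and makes `Tensor.affine` dispatch on the rank of its weight: a 1-D weight now means an elementwise scale-and-shift, while a 2-D weight keeps the existing matmul-plus-bias behavior. As an illustration only (not part of the patch), here is a minimal NumPy sketch of what the two branches compute, assuming a 2-D input of shape (N, D):

```python
import numpy as np

# Illustrative sketch: the two `affine` branches for a 2-D input x of shape (N, D).
x = np.random.randn(2, 4).astype(np.float32)

# 1-D params -> elementwise affine (the new branch), as applied after layernorm:
# y = x * gamma + beta, broadcast over the batch dimension.
gamma, beta = np.ones(4, np.float32), np.zeros(4, np.float32)
y_ln = x * gamma.reshape(1, -1) + beta.reshape(1, -1)

# 2-D params -> existing linear-layer branch: y = x @ W + b.
W, b = np.random.randn(4, 3).astype(np.float32), np.zeros(3, np.float32)
y_fc = x @ W + b.reshape(1, -1)

assert y_ln.shape == (2, 4) and y_fc.shape == (2, 3)
```

Because `ln1`/`ln2` start as ones and zeros, the elementwise affine is the identity at initialization, so the block's initial behavior is unchanged; the scale and shift only take effect as they are trained.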