layernorm with learnable parameters

George Hotz
2021-11-29 13:03:57 -05:00
parent c7f795ca1e
commit 1eafa5580e
2 changed files with 9 additions and 3 deletions


@@ -27,6 +27,9 @@ class TransformerBlock:
     self.ff1 = (Tensor.uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
     self.ff2 = (Tensor.uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
+    self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
+    self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))

   def __call__(self, x):
     # bs x T x embed_dim
     bs = x.shape[0]
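
The two new tuples mirror the usual gamma/beta parameters of layer normalization: a per-feature scale initialized to ones and a per-feature shift initialized to zeros, so the extra affine is an identity at initialization and the block starts out computing exactly what it did before. A quick check of that property in plain numpy (a sketch, not tinygrad code; the 1e-5 epsilon is an assumption):

import numpy as np

embed_dim = 8
scale, shift = np.ones(embed_dim), np.zeros(embed_dim)  # same shapes/init as self.ln1 / self.ln2

x = np.random.randn(4, embed_dim)
normed = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-5)
# with a scale of ones and a shift of zeros the learnable affine is the identity,
# so at initialization the output matches plain layernorm
assert np.allclose(normed * scale + shift, normed)
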
@@ -47,9 +50,9 @@ class TransformerBlock:
     attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
     x = inputs + attention.reshape(shape=(-1, embed_dim)).affine(self.final).dropout(0.1)
-    x = layernorm(x, embed_dim)
+    x = layernorm(x, embed_dim).affine(self.ln1)
     x = x + x.affine(self.ff1).relu().affine(self.ff2).dropout(0.1)
-    x = layernorm(x, embed_dim)
+    x = layernorm(x, embed_dim).affine(self.ln2)
     return x.reshape(shape=(bs, -1, embed_dim))

 class Transformer:
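
Taken together with the previous hunk, each sub-layer now follows the post-norm pattern: residual add, layer normalization, then the learnable per-feature scale and shift. A self-contained numpy sketch of that ordering; the attn and ff callables are hypothetical stand-ins for the attention and feed-forward computations, and dropout is omitted:

import numpy as np

def ln(x, scale, shift, eps=1e-5):
  # layer normalization over the feature axis, followed by the learnable scale/shift
  xhat = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)
  return xhat * scale + shift

def block(x, attn, ff, ln1, ln2):
  x = ln(x + attn(x), *ln1)  # attention sub-layer: residual add, then post-norm
  x = ln(x + ff(x), *ln2)    # feed-forward sub-layer: residual add, then post-norm
  return x

# toy usage with identity stand-ins for attention and feed-forward
embed_dim = 8
params = (np.ones(embed_dim), np.zeros(embed_dim))
out = block(np.random.randn(2, 5, embed_dim), lambda t: t, lambda t: t, params, params)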


@@ -276,7 +276,10 @@ class Tensor:
     return self._pool2d(*kernel_size).max(axis=(3,5))

   def affine(self, params):
-    return self.dot(params[0]).add(params[1].reshape(shape=[1, -1]))
+    if len(params[0].shape) == 1: # elementwise affine
+      return self.mul(params[0].reshape(shape=[1, -1])).add(params[1].reshape(shape=[1, -1]))
+    else:
+      return self.dot(params[0]).add(params[1].reshape(shape=[1, -1]))

 # An instantiation of the Function is the Context
 class Function:
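
The dispatch keys on the rank of the weight: a 1-D weight (the new self.ln1/self.ln2 parameters) means a per-feature multiply-and-add, while a 2-D weight (self.ff1, self.ff2, and friends) keeps the original matmul-plus-bias path. A rough numpy model of the same rule, a sketch of the behavior rather than the tinygrad implementation:

import numpy as np

def affine(x, params):
  w, b = params
  if w.ndim == 1:
    # 1-D weight: per-feature scale and shift (the path taken by self.ln1 / self.ln2)
    return x * w.reshape(1, -1) + b.reshape(1, -1)
  # 2-D weight: ordinary matmul plus bias (the path taken by self.ff1, self.ff2)
  return x @ w + b.reshape(1, -1)

x = np.random.randn(4, 8)
print(affine(x, (np.ones(8), np.zeros(8))).shape)               # (4, 8)  -- layernorm-style params
print(affine(x, (np.random.randn(8, 16), np.zeros(16))).shape)  # (4, 16) -- feed-forward-style params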