revert layernorm to have axis param

This commit is contained in:
George Hotz
2022-09-26 10:11:38 -04:00
parent dc80bf6f85
commit dec5334da9

View File

@@ -312,8 +312,7 @@ class Tensor:
def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
def layernorm(self, eps=1e-5):
axis = range(1, len(self.shape))
def layernorm(self, axis=-1, eps=1e-5):
y = (self - self.mean(axis=axis, keepdim=True))
return y.div((y*y).mean(axis=axis, keepdim=True).add(eps).sqrt())