diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 87555c4ef4..82e6ecf4f3 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -312,8 +312,7 @@ class Tensor: def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self) - def layernorm(self, eps=1e-5): - axis = range(1, len(self.shape)) + def layernorm(self, axis=-1, eps=1e-5): y = (self - self.mean(axis=axis, keepdim=True)) return y.div((y*y).mean(axis=axis, keepdim=True).add(eps).sqrt())