diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index feb9b33824..0eb8345c6f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3042,7 +3042,7 @@ class Tensor: """ return functools.reduce(lambda x,f: f(x), ll, self) - def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor: + def layernorm(self, axis:Union[int,Tuple[int,...]]=-1, eps:float=1e-5) -> Tensor: """ Applies Layer Normalization over a mini-batch of inputs.