type annotation for layernorm (#6883)

This commit is contained in:
chenyu
2024-10-04 09:03:56 -04:00
committed by GitHub
parent 8ca506ee37
commit 4c3895744e

View File

@@ -3042,7 +3042,7 @@ class Tensor:
"""
return functools.reduce(lambda x,f: f(x), ll, self)
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
def layernorm(self, axis:Union[int,Tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
"""
Applies Layer Normalization over a mini-batch of inputs.