LayerNorm2d for 2 lines

This commit is contained in:
George Hotz
2023-03-20 16:58:43 -07:00
parent 128ca160ac
commit d6f4219952
3 changed files with 31 additions and 8 deletions

View File

@@ -73,11 +73,15 @@ class GroupNorm:
class LayerNorm:
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*normalized_shape), Tensor.zeros(*normalized_shape)) if elementwise_affine else (None, None)
self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
def __call__(self, x:Tensor):
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
x = x.layernorm(eps=self.eps, axis=self.axis)
if not self.elementwise_affine: return x
return x * self.weight + self.bias
class LayerNorm2d(LayerNorm):
def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)