mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
LayerNorm2d for 2 lines
This commit is contained in:
@@ -73,11 +73,15 @@ class GroupNorm:
|
||||
|
||||
class LayerNorm:
|
||||
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
|
||||
normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
|
||||
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(normalized_shape))), eps, elementwise_affine
|
||||
self.weight, self.bias = (Tensor.ones(*normalized_shape), Tensor.zeros(*normalized_shape)) if elementwise_affine else (None, None)
|
||||
self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
|
||||
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
|
||||
self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
|
||||
x = x.layernorm(eps=self.eps, axis=self.axis)
|
||||
if not self.elementwise_affine: return x
|
||||
return x * self.weight + self.bias
|
||||
|
||||
class LayerNorm2d(LayerNorm):
|
||||
def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
Reference in New Issue
Block a user