mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
pow(2) -> square in RMSNorm [run_process_replay] (#5901)
reads nicer in metadata
This commit is contained in:
@@ -174,14 +174,14 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05
|
||||
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
|
||||
axis = tuple(range(2, x.ndim))
|
||||
mean = x.mean(axis=axis, keepdim=True)
|
||||
invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
|
||||
invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
|
||||
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
|
||||
|
||||
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
|
||||
assert stash_type == 1, "only float32 is supported"
|
||||
axis = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
|
||||
mean = x.mean(axis=axis, keepdim=True)
|
||||
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
|
||||
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
|
||||
|
||||
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
|
||||
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
|
||||
|
||||
@@ -295,7 +295,7 @@ class RMSNorm:
|
||||
"""
|
||||
def __init__(self, dim, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
|
||||
|
||||
def _norm(self, x:Tensor): return x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
|
||||
def _norm(self, x:Tensor): return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
|
||||
|
||||
|
||||
Reference in New Issue
Block a user