pow(2) -> square in RMSNorm [run_process_replay] (#5901)

`square()` reads nicer than `pow(2)` in metadata.
This commit is contained in:
chenyu
2024-08-04 14:21:31 -04:00
committed by GitHub
parent aad9234e52
commit 996ff0c135
2 changed files with 3 additions and 3 deletions

View File

@@ -174,14 +174,14 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05
# Normalizes x over every axis from index 2 onward, then applies scale and bias.
# epsilon is added to the mean of squared deviations before rsqrt() is taken.
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
# reduce over all axes at index 2 and above
axis = tuple(range(2, x.ndim))
mean = x.mean(axis=axis, keepdim=True)
# NOTE(review): the next two lines are a before/after pair from this commit's diff
# (pow(2) replaced by square()); they differ only in that one call.
invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
# scale and bias are reshaped to [-1, 1, 1] before being combined with the centered input
# (assumes this lets them broadcast against x -- TODO confirm for inputs that are not 4-D)
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
# Applies x.layernorm over the axes from `axis` through the last axis, then scale and bias.
# Returns a 3-tuple: (scaled and shifted layernorm result, mean over those axes,
# rsqrt of the mean squared deviation plus epsilon).
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
# any stash_type other than 1 fails this assert
assert stash_type == 1, "only float32 is supported"
# rebinds `axis` from an int to a tuple: a negative value is offset by x.ndim,
# then every axis from that index up to x.ndim - 1 is included
axis = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
mean = x.mean(axis=axis, keepdim=True)
# NOTE(review): the next two lines are a before/after pair from this commit's diff
# (pow(2) replaced by square()); they differ only in that one call.
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
# Reshapes x to (x.shape[0], num_groups, -1), runs layernorm over the last axis with eps=epsilon,
# applies scale and bias (each given a trailing axis via unsqueeze(-1)), and reshapes back to x.shape.
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)

View File

@@ -295,7 +295,7 @@ class RMSNorm:
"""
# Stores eps on the instance and sets weight to Tensor.ones(dim).
def __init__(self, dim, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
# Multiplies x by rsqrt() of (mean of x squared over the last axis, kept as a dim, plus self.eps).
# NOTE(review): the two definitions below are a before/after pair from this commit's diff
# (pow(2) replaced by square()); they differ only in that one call.
def _norm(self, x:Tensor): return x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
def _norm(self, x:Tensor): return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
# Runs _norm on x.float(), casts the result back to x.dtype, then multiplies by self.weight.
def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight