From 996ff0c1357b554e83ce8f720f960fec2547bb98 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sun, 4 Aug 2024 14:21:31 -0400
Subject: [PATCH] pow(2) -> square in RMSNorm [run_process_replay] (#5901)

reads nicer in metadata
---
 extra/onnx_ops.py       | 4 ++--
 tinygrad/nn/__init__.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py
index 9972775380..82db4d6f08 100644
--- a/extra/onnx_ops.py
+++ b/extra/onnx_ops.py
@@ -174,14 +174,14 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05
 def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
   axis = tuple(range(2, x.ndim))
   mean = x.mean(axis=axis, keepdim=True)
-  invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
+  invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
   return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))

 def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
   assert stash_type == 1, "only float32 is supported"
   axis = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
   mean = x.mean(axis=axis, keepdim=True)
-  return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
+  return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()

 def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
   return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index 6f62a4d1aa..0199c64ef3 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -295,7 +295,7 @@ class RMSNorm:
   """
   def __init__(self, dim, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)

-  def _norm(self, x:Tensor): return x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
+  def _norm(self, x:Tensor): return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()

   def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
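
Note (not part of the patch): pow(2) and square() compute the same values here, so the change only affects how the operation is rendered in kernel metadata. A minimal sketch checking the RMSNorm expression, assuming tinygrad and numpy are installed:

# Check (illustrative only) that square() matches pow(2) in the RMSNorm expression.
import numpy as np
from tinygrad import Tensor

x = Tensor.randn(2, 8)
eps = 1e-6
old = x * (x.pow(2).mean(-1, keepdim=True) + eps).rsqrt()    # previous form
new = x * (x.square().mean(-1, keepdim=True) + eps).rsqrt()  # patched form
assert np.allclose(old.numpy(), new.numpy())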