match torch rmsnorm implementation (#6799)

* update rmsnorm to match torch implementation

* run all tests

* formatting

* formatting

* oneline

* default to 1e-6

* restore old test

* formatting

* don't save elementwise_affine

* your message

* ignore webgpu

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Ryan Dorrington
2025-02-13 15:02:51 +10:00
committed by GitHub
parent 19ae829bd1
commit a66b8250e0
2 changed files with 37 additions and 8 deletions

View File

@@ -344,7 +344,7 @@ class TestNN(unittest.TestCase):
torch_layer = torch.nn.LayerNorm([H, W]).eval()
# create in tinygrad
layer = LayerNorm([H, W])
layer = LayerNorm((H, W))
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
@@ -446,7 +446,32 @@ class TestNN(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU can only run kernels with up to 10 buffers")
def test_rmsnorm(self):
class TorchRMSNorm(torch.nn.Module):
N, C, H, W = 20, 5, 10, 10
# create in torch
torch_layer = torch.nn.RMSNorm([H, W], elementwise_affine=True, eps=1e-6).eval()
# create in tinygrad
layer = RMSNorm((H, W), elementwise_affine=True, eps=1e-6)
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
for _ in range(10):
# forward
x = Tensor.randn(N, C, H, W, requires_grad=True)
z = layer(x)
z.sum().backward()
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
torch_z.sum().backward()
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU can only run kernels with up to 10 buffers")
def test_rmsnorm_llama(self):
class LlamaRMSNorm(torch.nn.Module):
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@@ -461,8 +486,8 @@ class TestNN(unittest.TestCase):
return output * self.weight
B, T, embed_size = 4, 10, 20
torch_layer = TorchRMSNorm(embed_size)
layer = RMSNorm(embed_size)
torch_layer = LlamaRMSNorm(embed_size)
layer = RMSNorm(embed_size, elementwise_affine=True, eps=1e-6)
layer.weight.requires_grad = True
for _ in range(10):

View File

@@ -299,11 +299,15 @@ class RMSNorm:
print(norm(t).numpy())
```
"""
def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
def __init__(self, normalized_shape: int|tuple[int, ...], eps: float = 1e-6, elementwise_affine: bool = True):
self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps = tuple(-1-i for i in range(len(self.normalized_shape))), eps
self.weight = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
def __call__(self, x: Tensor) -> Tensor:
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"Last dimensions of {x.shape} must match {self.normalized_shape}"
x = x * (x.square().mean(axis=self.axis, keepdim=True) + self.eps).rsqrt()
return x * self.weight if self.weight is not None else x
class Embedding:
"""