Revert "match torch rmsnorm implementation (#6799)" (#9052)

This reverts commit a66b8250e0.
Author:    George Hotz
Date:      2025-02-13 14:42:45 +08:00
Committer: GitHub
Parent:    a66b8250e0
Commit:    33a1151f2f
2 changed files with 8 additions and 37 deletions
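
What the revert restores, at the API level: RMSNorm goes back to taking a single trailing dimension and always allocating a weight, dropping the torch-style normalized_shape/eps/elementwise_affine constructor that #6799 introduced. A minimal usage sketch under the restored API (imports and values here are illustrative, not part of this commit):

from tinygrad import Tensor, dtypes, nn

norm = nn.RMSNorm(4, eps=1e-6)                           # weight starts as Tensor.ones(4)
t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
print(norm(t).numpy())                                   # each row scaled by 1/sqrt(mean(row**2) + eps)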

test/test_nn.py

@@ -344,7 +344,7 @@ class TestNN(unittest.TestCase):
     torch_layer = torch.nn.LayerNorm([H, W]).eval()
     # create in tinygrad
-    layer = LayerNorm((H, W))
+    layer = LayerNorm([H, W])
     layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
     layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
@@ -446,32 +446,7 @@ class TestNN(unittest.TestCase):
   @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU can only run kernels with up to 10 buffers")
   def test_rmsnorm(self):
-    N, C, H, W = 20, 5, 10, 10
-    # create in torch
-    torch_layer = torch.nn.RMSNorm([H, W], elementwise_affine=True, eps=1e-6).eval()
-    # create in tinygrad
-    layer = RMSNorm((H, W), elementwise_affine=True, eps=1e-6)
-    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
-    for _ in range(10):
-      # forward
-      x = Tensor.randn(N, C, H, W, requires_grad=True)
-      z = layer(x)
-      z.sum().backward()
-      torch_x = torch.tensor(x.numpy(), requires_grad=True)
-      torch_z = torch_layer(torch_x)
-      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
-      torch_z.sum().backward()
-      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
-      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
-
-  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU can only run kernels with up to 10 buffers")
-  def test_rmsnorm_llama(self):
-    class LlamaRMSNorm(torch.nn.Module):
+    class TorchRMSNorm(torch.nn.Module):
       # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
       def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -486,8 +461,8 @@ class TestNN(unittest.TestCase):
         return output * self.weight

     B, T, embed_size = 4, 10, 20
-    torch_layer = LlamaRMSNorm(embed_size)
-    layer = RMSNorm(embed_size, elementwise_affine=True, eps=1e-6)
+    torch_layer = TorchRMSNorm(embed_size)
+    layer = RMSNorm(embed_size)
     layer.weight.requires_grad = True
     for _ in range(10):
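
Between the two hunks above, the body of TorchRMSNorm is unchanged and therefore elided by the diff. For context, the reference class the restored test mirrors looks roughly like this; the elided middle is filled in from the linked llama model.py and should be read as a sketch, not as part of this commit:

import torch

class TorchRMSNorm(torch.nn.Module):
  # llama-style RMSNorm: normalize over the last axis in float32, then rescale by a learned weight
  def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    self.eps = eps
    self.weight = torch.nn.Parameter(torch.ones(dim))

  def _norm(self, x):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

  def forward(self, x):
    output = self._norm(x.float()).type_as(x)
    return output * self.weight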

tinygrad/nn/__init__.py

@@ -299,15 +299,11 @@ class RMSNorm:
   print(norm(t).numpy())
   ```
   """
-  def __init__(self, normalized_shape: int|tuple[int, ...], eps: float = 1e-6, elementwise_affine: bool = True):
-    self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
-    self.axis, self.eps = tuple(-1-i for i in range(len(self.normalized_shape))), eps
-    self.weight = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
+  def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)

-  def __call__(self, x: Tensor) -> Tensor:
-    assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"Last dimensions of {x.shape} must match {self.normalized_shape}"
-    x = x * (x.square().mean(axis=self.axis, keepdim=True) + self.eps).rsqrt()
-    return x * self.weight if self.weight is not None else x
+  def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
+
+  def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight

 class Embedding:
   """