match torch rmsnorm implementation (#6799)
* update rmsnorm to match torch implementation
* run all tests
* formatting
* formatting
* oneline
* default to 1e-6
* restore old test
* formatting
* don't save elementwise_affine
* your message
* ignore webgpu

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
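In short, the hunks below swap the old last-dimension-only constructor (`RMSNorm(dim, eps)`) for a torch-style one (`RMSNorm(normalized_shape, eps, elementwise_affine)`) that normalizes over the trailing `len(normalized_shape)` dimensions and makes the learnable weight optional. A minimal NumPy sketch of the math the reworked `__call__` performs (see the `class RMSNorm:` hunk below); the function and variable names here are illustrative, not part of the diff:

```python
# Hedged NumPy reference of what the reworked RMSNorm.__call__ computes.
import numpy as np

def rms_norm_ref(x: np.ndarray, normalized_shape: tuple[int, ...], weight=None, eps: float = 1e-6):
  axes = tuple(-1 - i for i in range(len(normalized_shape)))            # e.g. (-1, -2) for (H, W)
  y = x * (np.square(x).mean(axis=axes, keepdims=True) + eps) ** -0.5   # x * rsqrt(mean(x^2) + eps)
  return y * weight if weight is not None else y                        # optional elementwise affine

x = np.random.randn(20, 5, 10, 10).astype(np.float32)
out = rms_norm_ref(x, (10, 10), weight=np.ones((10, 10), dtype=np.float32))
print(out.shape)  # (20, 5, 10, 10)
```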
@@ -344,7 +344,7 @@ class TestNN(unittest.TestCase):
     torch_layer = torch.nn.LayerNorm([H, W]).eval()
 
     # create in tinygrad
-    layer = LayerNorm([H, W])
+    layer = LayerNorm((H, W))
     layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
     layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
 
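The one-line change above only passes a tuple instead of a list to `LayerNorm`. Either spelling is equivalent under the `normalized_shape` handling this PR gives `RMSNorm` (and which `LayerNorm` presumably already has); a standalone check of that conversion, copied from the `class RMSNorm:` hunk below:

```python
# The normalized_shape conversion applied to both spellings; helper name is illustrative.
def to_shape(normalized_shape):
  return (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)

assert to_shape([10, 10]) == to_shape((10, 10)) == (10, 10)  # list and tuple collapse to the same tuple
assert to_shape(10) == (10,)                                 # a bare int means the last dimension only
```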
@@ -446,7 +446,32 @@ class TestNN(unittest.TestCase):
 
   @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU can only run kernels with up to 10 buffers")
   def test_rmsnorm(self):
-    class TorchRMSNorm(torch.nn.Module):
+    N, C, H, W = 20, 5, 10, 10
+
+    # create in torch
+    torch_layer = torch.nn.RMSNorm([H, W], elementwise_affine=True, eps=1e-6).eval()
+
+    # create in tinygrad
+    layer = RMSNorm((H, W), elementwise_affine=True, eps=1e-6)
+    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
+
+    for _ in range(10):
+      # forward
+      x = Tensor.randn(N, C, H, W, requires_grad=True)
+      z = layer(x)
+      z.sum().backward()
+
+      torch_x = torch.tensor(x.numpy(), requires_grad=True)
+      torch_z = torch_layer(torch_x)
+      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
+      torch_z.sum().backward()
+
+      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
+      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
+
+  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU can only run kernels with up to 10 buffers")
+  def test_rmsnorm_llama(self):
+    class LlamaRMSNorm(torch.nn.Module):
       # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
       def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
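The new `test_rmsnorm` above checks the tinygrad layer against `torch.nn.RMSNorm` directly. For readers without the tinygrad side handy, a torch-only sketch of the reference behaviour being matched (this assumes a torch build that ships `torch.nn.RMSNorm`, roughly 2.4+, and is not part of the diff):

```python
# torch-only sanity check: torch.nn.RMSNorm over the last two dims equals the explicit formula.
import torch

x = torch.randn(20, 5, 10, 10)
layer = torch.nn.RMSNorm([10, 10], eps=1e-6, elementwise_affine=True)
manual = x * torch.rsqrt(x.square().mean(dim=(-2, -1), keepdim=True) + 1e-6) * layer.weight
assert torch.allclose(layer(x), manual, atol=1e-5)
```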
@@ -461,8 +486,8 @@ class TestNN(unittest.TestCase):
         return output * self.weight
 
     B, T, embed_size = 4, 10, 20
-    torch_layer = TorchRMSNorm(embed_size)
-    layer = RMSNorm(embed_size)
+    torch_layer = LlamaRMSNorm(embed_size)
+    layer = RMSNorm(embed_size, elementwise_affine=True, eps=1e-6)
     layer.weight.requires_grad = True
 
     for _ in range(10):
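`test_rmsnorm_llama` keeps the old comparison against the Llama reference module; only its first and last lines appear in these hunks. Per the meta-llama file linked above, the elided middle is roughly the following (a reconstruction for context, not text from the diff):

```python
# Rough reconstruction of the LlamaRMSNorm helper from the linked meta-llama reference;
# only `super().__init__()` and `return output * self.weight` are visible in the hunks above.
import torch

class LlamaRMSNorm(torch.nn.Module):
  def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    self.eps = eps
    self.weight = torch.nn.Parameter(torch.ones(dim))

  def _norm(self, x):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

  def forward(self, x):
    # normalize in float32 then cast back, as the Llama reference does
    output = self._norm(x.float()).type_as(x)
    return output * self.weight
```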
@@ -299,11 +299,15 @@ class RMSNorm:
   print(norm(t).numpy())
   ```
   """
-  def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
+  def __init__(self, normalized_shape: int|tuple[int, ...], eps: float = 1e-6, elementwise_affine: bool = True):
+    self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
+    self.axis, self.eps = tuple(-1-i for i in range(len(self.normalized_shape))), eps
+    self.weight = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
 
-  def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
-
-  def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
+  def __call__(self, x: Tensor) -> Tensor:
+    assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"Last dimensions of {x.shape} must match {self.normalized_shape}"
+    x = x * (x.square().mean(axis=self.axis, keepdim=True) + self.eps).rsqrt()
+    return x * self.weight if self.weight is not None else x
 
 class Embedding:
   """
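Finally, a short usage sketch of the reworked class as the hunk above defines it; the import paths are the usual tinygrad ones and are assumed here rather than shown in the diff:

```python
# Usage sketch of the reworked RMSNorm; import paths are assumed, not part of the diff.
from tinygrad import Tensor
from tinygrad.nn import RMSNorm

x = Tensor.randn(4, 10, 20, 30)

norm_last = RMSNorm(30)                            # int: normalize over the last dim, weight shape (30,)
norm_2d   = RMSNorm((20, 30))                      # tuple: normalize over the last two dims, weight shape (20, 30)
norm_bare = RMSNorm(30, elementwise_affine=False)  # no learnable weight; __call__ skips the multiply

print(norm_last(x).shape, norm_2d(x).shape, norm_bare(x).shape)  # all (4, 10, 20, 30)
# RMSNorm((10, 20))(x) would trip the new assert: x's trailing dims (20, 30) != (10, 20)
```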