rms matching pytorch implementation (#10319)

* rms matching pytorch implementation

* pre commit fix

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
वेदांत
2025-05-17 20:53:11 +05:30
committed by GitHub
parent da2b1834b4
commit 2453d99050
2 changed files with 26 additions and 5 deletions

View File

@@ -446,17 +446,18 @@ class TestNN(unittest.TestCase):
def test_rmsnorm(self):
class TorchRMSNorm(torch.nn.Module):
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
def __init__(self, dim: int, eps: float = 1e-6):
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))
self.elementwise_affine = elementwise_affine
self.weight = torch.nn.Parameter(torch.ones(dim)) if elementwise_affine else None
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
return output if self.weight is None else output * self.weight
B, T, embed_size = 4, 10, 20
torch_layer = TorchRMSNorm(embed_size)
@@ -477,6 +478,22 @@ class TestNN(unittest.TestCase):
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
torch_layer = TorchRMSNorm(embed_size, elementwise_affine=False)
layer = RMSNorm(embed_size, elementwise_affine=False)
for _ in range(10):
# forward
x = Tensor.randn(B, T, embed_size, requires_grad=True)
z = layer(x)
z.sum().backward()
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
torch_z.sum().backward()
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
def test_embedding(self):
B, T, embed_size, vocab_size = 4, 10, 20, 28

View File

@@ -299,11 +299,15 @@ class RMSNorm:
print(norm(t).numpy())
```
"""
def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
def __init__(self, dim:int, eps=1e-6, elementwise_affine=True):
self.eps = eps
self.weight = Tensor.ones(dim) if elementwise_affine else None
def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
def __call__(self, x:Tensor) -> Tensor:
x = self._norm(x.float()).cast(x.dtype)
return x if self.weight is None else x * self.weight
class Embedding:
"""