rms matching pytorch implementation (#10319)
* rms matching pytorch implementation
* pre commit fix

Co-authored-by: chenyu <chenyu@fastmail.com>
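For context, RMSNorm divides each feature vector by the root mean square of its elements (plus eps) and, when elementwise_affine is enabled, scales the result by a learned per-feature weight. A minimal NumPy sketch of that formula (illustrative only, not code from this commit):

```python
import numpy as np

def rms_norm(x: np.ndarray, weight=None, eps: float = 1e-6) -> np.ndarray:
  # normalize by the root-mean-square over the last axis
  y = x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
  # elementwise_affine=False corresponds to weight=None: no learned scale
  return y if weight is None else y * weight

x = np.random.randn(4, 10, 20).astype(np.float32)
print(rms_norm(x).shape)  # (4, 10, 20)
```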
@@ -446,17 +446,18 @@ class TestNN(unittest.TestCase):
   def test_rmsnorm(self):
     class TorchRMSNorm(torch.nn.Module):
       # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
-      def __init__(self, dim: int, eps: float = 1e-6):
+      def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
         super().__init__()
         self.eps = eps
-        self.weight = torch.nn.Parameter(torch.ones(dim))
+        self.elementwise_affine = elementwise_affine
+        self.weight = torch.nn.Parameter(torch.ones(dim)) if elementwise_affine else None
 
       def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
       def forward(self, x):
         output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        return output if self.weight is None else output * self.weight
 
     B, T, embed_size = 4, 10, 20
     torch_layer = TorchRMSNorm(embed_size)
@@ -477,6 +478,22 @@ class TestNN(unittest.TestCase):
       np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
       np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
 
+    torch_layer = TorchRMSNorm(embed_size, elementwise_affine=False)
+    layer = RMSNorm(embed_size, elementwise_affine=False)
+
+    for _ in range(10):
+      # forward
+      x = Tensor.randn(B, T, embed_size, requires_grad=True)
+      z = layer(x)
+      z.sum().backward()
+
+      torch_x = torch.tensor(x.numpy(), requires_grad=True)
+      torch_z = torch_layer(torch_x)
+      torch_z.sum().backward()
+
+      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
+      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
+
   def test_embedding(self):
     B, T, embed_size, vocab_size = 4, 10, 20, 28
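The test above keeps a local copy of the Llama reference module, so it runs on any PyTorch version. On a PyTorch build that ships torch.nn.RMSNorm (2.4+), the same parity check can be sketched directly against the built-in layer; this snippet is illustrative and not part of the commit:

```python
import numpy as np
import torch
from tinygrad import Tensor
from tinygrad.nn import RMSNorm

dim = 20
for affine in (True, False):
  layer = RMSNorm(dim, elementwise_affine=affine)
  torch_layer = torch.nn.RMSNorm(dim, eps=1e-6, elementwise_affine=affine)
  x = Tensor.randn(4, 10, dim)
  out = layer(x).numpy()
  torch_out = torch_layer(torch.tensor(x.numpy())).detach().numpy()
  np.testing.assert_allclose(out, torch_out, atol=5e-6, rtol=5e-6)
```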
@@ -299,11 +299,15 @@ class RMSNorm:
   print(norm(t).numpy())
   ```
   """
-  def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
+  def __init__(self, dim:int, eps=1e-6, elementwise_affine=True):
+    self.eps = eps
+    self.weight = Tensor.ones(dim) if elementwise_affine else None
 
   def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
 
-  def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
+  def __call__(self, x:Tensor) -> Tensor:
+    x = self._norm(x.float()).cast(x.dtype)
+    return x if self.weight is None else x * self.weight
 
 class Embedding:
   """