tests from grad uop path [pr] (#8313)

George Hotz
2024-12-18 09:25:05 -08:00
committed by GitHub
parent 6a1987f9f9
commit bd9c015b09
6 changed files with 55 additions and 42 deletions

@@ -325,13 +325,13 @@ class TestNN(unittest.TestCase):
     # forward
     x = Tensor.randn(BS, C, H, W, requires_grad=True)
     z = layer(x)
-    z.sum().backward()
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     # backward
+    z.sum().backward()
     torch_z.sum().backward(retain_graph=True)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
     np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
     np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
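
Every hunk in this file makes the same one-line move: z.sum().backward() leaves the forward block, where it previously ran before the forward outputs were compared, and lands at the top of the backward block. A minimal standalone sketch of the resulting test pattern, using GroupNorm as a stand-in layer and made-up shapes (the hunks do not name the layer under test; the tolerances are copied from the hunk above):

import numpy as np
import torch
from tinygrad import Tensor
from tinygrad.nn import GroupNorm

BS, C, H, W, G = 4, 6, 8, 8, 3  # illustrative shapes, not from the diff
layer = GroupNorm(G, C)
with torch.no_grad():
  torch_layer = torch.nn.GroupNorm(G, C).eval()
  torch_layer.weight[:] = torch.tensor(layer.weight.numpy())
  torch_layer.bias[:] = torch.tensor(layer.bias.numpy())

# forward: compare outputs before any backward call
x = Tensor.randn(BS, C, H, W, requires_grad=True)
z = layer(x)
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)

# backward: backward now runs only after the forward comparison
z.sum().backward()
torch_z.sum().backward(retain_graph=True)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
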
@@ -351,13 +351,13 @@ class TestNN(unittest.TestCase):
     # forward
     x = Tensor.randn(N, C, H, W, requires_grad=True)
     z = layer(x)
-    z.sum().backward()
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     # backward
+    z.sum().backward()
     torch_z.sum().backward(retain_graph=True)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
     np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
     np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -377,13 +377,13 @@ class TestNN(unittest.TestCase):
     # forward
     x = Tensor.randn(N, C, H, W, requires_grad=True)
     z = layer(x)
-    z.sum().backward()
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     # backward
+    z.sum().backward()
     torch_z.sum().backward(retain_graph=True)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
     np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
     np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
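
This hunk departs from the template in one respect: the torch reference runs channels-last, so the input is permuted NCHW -> NHWC before the torch call and the output is permuted back. A torch-only sketch of that round-trip, using torch.nn.LayerNorm as a plausible stand-in (the hunk does not show which layer is under test; shapes are illustrative):

import torch

N, C, H, W = 2, 6, 4, 4
x = torch.randn(N, C, H, W)

# LayerNorm normalizes over the trailing dimensions, so it wants channels last
layer = torch.nn.LayerNorm(C)
y = layer(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # NCHW -> NHWC -> layer -> NCHW
assert y.shape == (N, C, H, W)
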
@@ -403,13 +403,13 @@ class TestNN(unittest.TestCase):
     # forward
     x = Tensor.randn(N, C, H, W, requires_grad=True)
     z = layer(x)
-    z.sum().backward()
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     # backward
+    z.sum().backward()
     torch_z.sum().backward(retain_graph=True)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
     np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
     np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -429,13 +429,13 @@ class TestNN(unittest.TestCase):
     # forward
     x = Tensor.randn(N, C, D, H, W, requires_grad=True)
     z = layer(x)
-    z.sum().backward()
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     # backward
+    z.sum().backward()
     torch_z.sum().backward(retain_graph=True)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
     np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
     np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -464,13 +464,13 @@ class TestNN(unittest.TestCase):
     # forward
     x = Tensor.randn(B, T, embed_size, requires_grad=True)
     z = layer(x)
-    z.sum().backward()
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     # backward
+    z.sum().backward()
     torch_z.sum().backward(retain_graph=True)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
     np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
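
Throughout these hunks the torch side calls backward(retain_graph=True) while the tinygrad side does not: torch frees the autograd graph after the first backward pass by default, and retaining it keeps the test free to backprop through torch_z again. A torch-only illustration of the flag (hypothetical values, not from the tests):

import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2.0).sum()
y.backward(retain_graph=True)  # keep the graph alive for another pass
y.backward()                   # second pass: would raise RuntimeError without retain_graph above
print(x.grad)                  # gradients accumulate: tensor([4., 4., 4.])
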