From bd9c015b09817be7457e1b33eff0548e072e93f0 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:25:05 -0800 Subject: [PATCH] tests from grad uop path [pr] (#8313) --- extra/gradcheck.py | 9 +++---- test/models/test_end2end.py | 10 ++++---- test/test_linearizer.py | 3 ++- test/test_nn.py | 48 ++++++++++++++++++------------------- test/test_tensor.py | 14 +++++------ test/unit/test_gradient.py | 13 +++++++++- 6 files changed, 55 insertions(+), 42 deletions(-) diff --git a/extra/gradcheck.py b/extra/gradcheck.py index 4d99726cce..a94befac39 100644 --- a/extra/gradcheck.py +++ b/extra/gradcheck.py @@ -1,8 +1,8 @@ import numpy as np -from tinygrad.tensor import Tensor +from tinygrad.tensor import Tensor, _to_np_dtype def mask_like(like, mask_inx, mask_value = 1.0): - mask = np.zeros_like(like).reshape(-1) + mask = np.zeros(like.shape, dtype=_to_np_dtype(like.dtype)).reshape(-1) mask[mask_inx] = mask_value return mask.reshape(like.shape) @@ -19,7 +19,8 @@ def jacobian(func, input): # tinygrad doesn't support slicing, tiny-hack to select # the needed scalar an backpropagate only through it - o_scalar = Tensor(mask_like(output.numpy(), o, 1.)).mul(output).sum() + o_scalar = Tensor(mask_like(output, o, 1.)).mul(output).sum() + o_scalar = Tensor(mask_like(output, o, 1.)).mul(output).sum() o_scalar.backward() for i, grad in enumerate(input.grad.numpy().reshape(-1)): @@ -34,7 +35,7 @@ def numerical_jacobian(func, input, eps = 1e-3): NJ = np.zeros((jo, ji), dtype=np.float32) for i in range(ji): - eps_perturb = mask_like(input.numpy(), i, mask_value = eps) + eps_perturb = mask_like(input, i, mask_value = eps) output_perturb_add = func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1) output_perturb_sub = func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1) diff --git a/test/models/test_end2end.py b/test/models/test_end2end.py index 9c84d83d0b..82dc8cc7a6 100644 --- a/test/models/test_end2end.py +++ b/test/models/test_end2end.py @@ -24,14 +24,9 @@ def compare_tiny_torch(model, model_torch, X, Y): out = model(X) loss = (out * Y).mean() - if not CI: print(loss.realize().numpy()) out_torch = model_torch(torch.Tensor(X.numpy())) loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean() - if not CI: print(loss_torch.detach().numpy()) - - # assert losses match - np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4) # zero and backward optimizer.zero_grad() @@ -39,6 +34,11 @@ def compare_tiny_torch(model, model_torch, X, Y): optimizer_torch.zero_grad() loss_torch.backward() + # assert losses match + if not CI: print(loss.realize().numpy()) + if not CI: print(loss_torch.detach().numpy()) + np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4) + for k,v in list(model_torch.named_parameters())[::-1]: g = model_state_dict[k].grad.numpy() gt = v.grad.detach().numpy() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index de7d64e31f..c7c441ebca 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1716,6 +1716,8 @@ class TestHandCodedOpts(unittest.TestCase): with Context(WINO=1): x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize() out = Tensor.conv2d(x,w, padding=1) + out.mean().backward() + upcasts = [] wino_schedule = create_schedule([out.lazydata]) # collect upcasts of tile transform kernels @@ -1730,7 +1732,6 @@ class TestHandCodedOpts(unittest.TestCase): # this test case's inputs 
are too small, so one of the 4-stacks became a local, which is fine i guess assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1 - out.mean().backward() backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata]) for si in backward_schedule: k = Kernel(si.ast) diff --git a/test/test_nn.py b/test/test_nn.py index 198b0eb1a6..3e45d5c6f9 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -325,13 +325,13 @@ class TestNN(unittest.TestCase): # forward x = Tensor.randn(BS, C, H, W, requires_grad=True) z = layer(x) + z.sum().backward() + torch_x = torch.tensor(x.numpy(), requires_grad=True) torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - - # backward - z.sum().backward() torch_z.sum().backward(retain_graph=True) + + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4) np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4) np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4) @@ -351,13 +351,13 @@ class TestNN(unittest.TestCase): # forward x = Tensor.randn(N, C, H, W, requires_grad=True) z = layer(x) + z.sum().backward() + torch_x = torch.tensor(x.numpy(), requires_grad=True) torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - - # backward - z.sum().backward() torch_z.sum().backward(retain_graph=True) + + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4) np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4) np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4) @@ -377,13 +377,13 @@ class TestNN(unittest.TestCase): # forward x = Tensor.randn(N, C, H, W, requires_grad=True) z = layer(x) + z.sum().backward() + torch_x = torch.tensor(x.numpy(), requires_grad=True) torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - - # backward - z.sum().backward() torch_z.sum().backward(retain_graph=True) + + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4) np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4) np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4) @@ -403,13 +403,13 @@ class TestNN(unittest.TestCase): # forward x = Tensor.randn(N, C, H, W, requires_grad=True) z = layer(x) + z.sum().backward() + torch_x = torch.tensor(x.numpy(), requires_grad=True) torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - - # backward - z.sum().backward() torch_z.sum().backward(retain_graph=True) + + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3) np.testing.assert_allclose(layer.weight.grad.numpy(), 
torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3) np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3) @@ -429,13 +429,13 @@ class TestNN(unittest.TestCase): # forward x = Tensor.randn(N, C, D, H, W, requires_grad=True) z = layer(x) + z.sum().backward() + torch_x = torch.tensor(x.numpy(), requires_grad=True) torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - - # backward - z.sum().backward() torch_z.sum().backward(retain_graph=True) + + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3) np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3) np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3) @@ -464,13 +464,13 @@ class TestNN(unittest.TestCase): # forward x = Tensor.randn(B, T, embed_size, requires_grad=True) z = layer(x) + z.sum().backward() + torch_x = torch.tensor(x.numpy(), requires_grad=True) torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - - # backward - z.sum().backward() torch_z.sum().backward(retain_graph=True) + + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3) np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3) diff --git a/test/test_tensor.py b/test/test_tensor.py index 7ab6c6354a..3a5853e2c5 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -67,15 +67,15 @@ class TestTinygrad(unittest.TestCase): out = out.log_softmax() out = out.mul(m).add(m).sum() out.backward(retain_graph=True) - xgrad,wgrad = x.grad.numpy(), W.grad.numpy() + xgrad,wgrad = x.grad, W.grad out.backward(retain_graph=True) - xgrad2,wgrad2 = x.grad.numpy(), W.grad.numpy() + xgrad2,wgrad2 = x.grad, W.grad out.backward() # no need to retain again since we will not re-run backward - xgrad3,wgrad3 = x.grad.numpy(), W.grad.numpy() - np.testing.assert_allclose(xgrad3, xgrad * 3., atol=1e-6) - np.testing.assert_allclose(wgrad3, wgrad * 3., atol=1e-6) - np.testing.assert_allclose(xgrad2, xgrad * 2., atol=1e-6) - np.testing.assert_allclose(wgrad2, wgrad * 2., atol=1e-6) + xgrad3,wgrad3 = x.grad, W.grad + np.testing.assert_allclose(xgrad3.numpy(), xgrad.numpy() * 3., atol=1e-6) + np.testing.assert_allclose(wgrad3.numpy(), wgrad.numpy() * 3., atol=1e-6) + np.testing.assert_allclose(xgrad2.numpy(), xgrad.numpy() * 2., atol=1e-6) + np.testing.assert_allclose(wgrad2.numpy(), wgrad.numpy() * 2., atol=1e-6) def test_second_order_backward_pass(self): def test_pytorch(): diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py index bb503f2f1e..21182874f2 100644 --- a/test/unit/test_gradient.py +++ b/test/unit/test_gradient.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp from tinygrad import Tensor from tinygrad.dtype import dtypes -from tinygrad.ops import UOp +from tinygrad.ops import UOp, Ops from tinygrad.gradient import compute_gradient class TestGradient(unittest.TestCase): @@ -93,5 +93,16 @@ class TestTensorGradient(unittest.TestCase): dx = z.gradient(x, gradient=dz)[0] self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0]) +class 
TestRealizeMeansRealize(unittest.TestCase): + def test_randn_realizes(self): + x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize() + self.assertEqual(x.lazydata.op, Ops.VIEW) + + @unittest.expectedFailure + def test_uniform_realizes(self): + x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize() + print(x.lazydata) + self.assertEqual(x.lazydata.op, Ops.VIEW) + if __name__ == '__main__': unittest.main()
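
Usage sketch (illustration, not part of the patch): with mask_like now taking a Tensor directly, the extra/gradcheck helpers touched above can be driven with Tensor inputs, mirroring the pattern used in test/test_tensor.py. The weights, input shapes, and function below are assumptions chosen for illustration, not taken from the patch.

import numpy as np
from tinygrad import Tensor
from extra.gradcheck import jacobian, numerical_jacobian

# illustrative shapes and function (not from the patch); any Tensor -> Tensor function works
W = Tensor.randn(10, 5, requires_grad=True)
x = Tensor.randn(1, 10, requires_grad=True)
func = lambda t: t.dot(W).relu().log_softmax()

J = jacobian(func, x)             # analytic Jacobian, built via backward() one output scalar at a time
NJ = numerical_jacobian(func, x)  # finite-difference Jacobian using the Tensor-based mask_like
np.testing.assert_allclose(J, NJ, atol=1e-3, rtol=1e-3)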