tests from grad uop path [pr] (#8313)

George Hotz
2024-12-18 09:25:05 -08:00
committed by GitHub
parent 6a1987f9f9
commit bd9c015b09
6 changed files with 55 additions and 42 deletions

View File

@@ -1,8 +1,8 @@
import numpy as np
- from tinygrad.tensor import Tensor
+ from tinygrad.tensor import Tensor, _to_np_dtype
def mask_like(like, mask_inx, mask_value = 1.0):
- mask = np.zeros_like(like).reshape(-1)
+ mask = np.zeros(like.shape, dtype=_to_np_dtype(like.dtype)).reshape(-1)
mask[mask_inx] = mask_value
return mask.reshape(like.shape)
@@ -19,7 +19,8 @@ def jacobian(func, input):
# tinygrad doesn't support slicing, tiny-hack to select
# the needed scalar and backpropagate only through it
- o_scalar = Tensor(mask_like(output.numpy(), o, 1.)).mul(output).sum()
+ o_scalar = Tensor(mask_like(output, o, 1.)).mul(output).sum()
o_scalar.backward()
for i, grad in enumerate(input.grad.numpy().reshape(-1)):
@@ -34,7 +35,7 @@ def numerical_jacobian(func, input, eps = 1e-3):
NJ = np.zeros((jo, ji), dtype=np.float32)
for i in range(ji):
- eps_perturb = mask_like(input.numpy(), i, mask_value = eps)
+ eps_perturb = mask_like(input, i, mask_value = eps)
output_perturb_add = func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
output_perturb_sub = func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
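For reference, a minimal sketch of how the updated helper behaves now that it takes a Tensor directly instead of a numpy array (shapes and indices below are illustrative, not from the test):

import numpy as np
from tinygrad.tensor import Tensor, _to_np_dtype

def mask_like(like, mask_inx, mask_value=1.0):
  # build the mask from the Tensor's shape/dtype; no .numpy() round-trip needed
  mask = np.zeros(like.shape, dtype=_to_np_dtype(like.dtype)).reshape(-1)
  mask[mask_inx] = mask_value
  return mask.reshape(like.shape)

t = Tensor.randn(2, 3)                   # only shape/dtype are read, the Tensor can stay unrealized
m = mask_like(t, 4, mask_value=1.0)
assert m.shape == (2, 3) and m.reshape(-1)[4] == 1.0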

View File

@@ -24,14 +24,9 @@ def compare_tiny_torch(model, model_torch, X, Y):
out = model(X)
loss = (out * Y).mean()
- if not CI: print(loss.realize().numpy())
out_torch = model_torch(torch.Tensor(X.numpy()))
loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
- if not CI: print(loss_torch.detach().numpy())
- # assert losses match
- np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)
# zero and backward
optimizer.zero_grad()
@@ -39,6 +34,11 @@ def compare_tiny_torch(model, model_torch, X, Y):
optimizer_torch.zero_grad()
loss_torch.backward()
+ # assert losses match
+ if not CI: print(loss.realize().numpy())
+ if not CI: print(loss_torch.detach().numpy())
+ np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)
for k,v in list(model_torch.named_parameters())[::-1]:
g = model_state_dict[k].grad.numpy()
gt = v.grad.detach().numpy()
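The two hunks above move the loss comparison from before the backward pass to after it. A condensed, stand-in version of the resulting ordering (the tiny matmul model here is illustrative, not the test's model):

import numpy as np, torch
from tinygrad import Tensor

x_np = np.random.randn(4, 8).astype(np.float32)
w_np = np.random.randn(8, 1).astype(np.float32)

# tinygrad side: forward, then backward
w = Tensor(w_np, requires_grad=True)
loss = Tensor(x_np).matmul(w).mean()
loss.backward()

# torch side: forward, then backward
tw = torch.tensor(w_np, requires_grad=True)
tloss = torch.tensor(x_np).matmul(tw).mean()
tloss.backward()

# losses (and grads) are only compared once both backward passes have run
np.testing.assert_allclose(loss.numpy(), tloss.detach().numpy(), atol=1e-4)
np.testing.assert_allclose(w.grad.numpy(), tw.grad.detach().numpy(), atol=1e-4)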

View File

@@ -1716,6 +1716,8 @@ class TestHandCodedOpts(unittest.TestCase):
with Context(WINO=1):
x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
out = Tensor.conv2d(x,w, padding=1)
+ out.mean().backward()
upcasts = []
wino_schedule = create_schedule([out.lazydata])
# collect upcasts of tile transform kernels
@@ -1730,7 +1732,6 @@ class TestHandCodedOpts(unittest.TestCase):
# this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess
assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1
- out.mean().backward()
backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
for si in backward_schedule:
k = Kernel(si.ast)
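The change above runs out.mean().backward() before the forward schedule is created instead of after the winograd assertions. A minimal sketch of that ordering in isolation (the create_schedule import path is an assumption; match whatever the test file actually imports):

from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule  # import path assumed

x = Tensor.rand(4, 4, requires_grad=True).realize()
out = (x * 2).sum()
out.backward()                                     # backward first, mirroring the hunk above
fwd_schedule = create_schedule([out.lazydata])     # schedule the forward output
bwd_schedule = create_schedule([x.grad.lazydata])  # x.grad exists because backward already ran
print(len(fwd_schedule), len(bwd_schedule))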

View File

@@ -325,13 +325,13 @@ class TestNN(unittest.TestCase):
# forward
x = Tensor.randn(BS, C, H, W, requires_grad=True)
z = layer(x)
+ z.sum().backward()
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
- np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# backward
- z.sum().backward()
torch_z.sum().backward(retain_graph=True)
+ np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -351,13 +351,13 @@ class TestNN(unittest.TestCase):
# forward
x = Tensor.randn(N, C, H, W, requires_grad=True)
z = layer(x)
+ z.sum().backward()
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
- np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# backward
- z.sum().backward()
torch_z.sum().backward(retain_graph=True)
+ np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -377,13 +377,13 @@ class TestNN(unittest.TestCase):
# forward
x = Tensor.randn(N, C, H, W, requires_grad=True)
z = layer(x)
+ z.sum().backward()
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
- np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# backward
- z.sum().backward()
torch_z.sum().backward(retain_graph=True)
+ np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -403,13 +403,13 @@ class TestNN(unittest.TestCase):
# forward
x = Tensor.randn(N, C, H, W, requires_grad=True)
z = layer(x)
+ z.sum().backward()
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
- np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# backward
- z.sum().backward()
torch_z.sum().backward(retain_graph=True)
+ np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -429,13 +429,13 @@ class TestNN(unittest.TestCase):
# forward
x = Tensor.randn(N, C, D, H, W, requires_grad=True)
z = layer(x)
+ z.sum().backward()
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
- np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# backward
- z.sum().backward()
torch_z.sum().backward(retain_graph=True)
+ np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -464,13 +464,13 @@ class TestNN(unittest.TestCase):
# forward
x = Tensor.randn(B, T, embed_size, requires_grad=True)
z = layer(x)
+ z.sum().backward()
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
- np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# backward
- z.sum().backward()
torch_z.sum().backward(retain_graph=True)
+ np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
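Each of the six hunks above applies the same reordering: tinygrad's backward now runs immediately after the forward pass, and the output comparison moves to after torch's backward call. A condensed, parameter-free stand-in for that pattern (the op and tolerances are illustrative, not the layers under test):

import numpy as np, torch
from tinygrad import Tensor

# forward: tinygrad backward runs right after the forward pass
x = Tensor.randn(2, 8, requires_grad=True)
z = x.relu() * 2
z.sum().backward()
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_x.relu() * 2

# backward: torch backward first, then every comparison
torch_z.sum().backward()
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)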

View File

@@ -67,15 +67,15 @@ class TestTinygrad(unittest.TestCase):
out = out.log_softmax()
out = out.mul(m).add(m).sum()
out.backward(retain_graph=True)
- xgrad,wgrad = x.grad.numpy(), W.grad.numpy()
+ xgrad,wgrad = x.grad, W.grad
out.backward(retain_graph=True)
- xgrad2,wgrad2 = x.grad.numpy(), W.grad.numpy()
+ xgrad2,wgrad2 = x.grad, W.grad
out.backward() # no need to retain again since we will not re-run backward
- xgrad3,wgrad3 = x.grad.numpy(), W.grad.numpy()
- np.testing.assert_allclose(xgrad3, xgrad * 3., atol=1e-6)
- np.testing.assert_allclose(wgrad3, wgrad * 3., atol=1e-6)
- np.testing.assert_allclose(xgrad2, xgrad * 2., atol=1e-6)
- np.testing.assert_allclose(wgrad2, wgrad * 2., atol=1e-6)
+ xgrad3,wgrad3 = x.grad, W.grad
+ np.testing.assert_allclose(xgrad3.numpy(), xgrad.numpy() * 3., atol=1e-6)
+ np.testing.assert_allclose(wgrad3.numpy(), wgrad.numpy() * 3., atol=1e-6)
+ np.testing.assert_allclose(xgrad2.numpy(), xgrad.numpy() * 2., atol=1e-6)
+ np.testing.assert_allclose(wgrad2.numpy(), wgrad.numpy() * 2., atol=1e-6)
def test_second_order_backward_pass(self):
def test_pytorch():
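The hunk above now snapshots x.grad and W.grad as Tensors after each backward call and defers .numpy() to the assertions; each extra backward pass accumulates another copy of the gradient. A minimal sketch of that accumulation behavior (a toy graph, not the test's network):

import numpy as np
from tinygrad import Tensor

x = Tensor.randn(3, requires_grad=True)
out = (x * x).sum()
out.backward(retain_graph=True)
xgrad = x.grad                     # keep a Tensor reference, as the test now does
out.backward()                     # second pass accumulates into x.grad
np.testing.assert_allclose(x.grad.numpy(), xgrad.numpy() * 2., atol=1e-6)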

View File

@@ -4,7 +4,7 @@ import jax
import jax.numpy as jnp
from tinygrad import Tensor
from tinygrad.dtype import dtypes
- from tinygrad.ops import UOp
+ from tinygrad.ops import UOp, Ops
from tinygrad.gradient import compute_gradient
class TestGradient(unittest.TestCase):
@@ -93,5 +93,16 @@ class TestTensorGradient(unittest.TestCase):
dx = z.gradient(x, gradient=dz)[0]
self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])
+ class TestRealizeMeansRealize(unittest.TestCase):
+ def test_randn_realizes(self):
+ x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
+ self.assertEqual(x.lazydata.op, Ops.VIEW)
+ @unittest.expectedFailure
+ def test_uniform_realizes(self):
+ x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
+ print(x.lazydata)
+ self.assertEqual(x.lazydata.op, Ops.VIEW)
if __name__ == '__main__':
unittest.main()
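For context on the Tensor.gradient API visible in the unchanged lines of the last hunk, a minimal sketch with hand-checkable numbers (the values are illustrative):

from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
z = x * x                              # elementwise square, shape (3,)
dz = Tensor([1.0, 1.0, 1.0])           # upstream gradient
dx = z.gradient(x, gradient=dz)[0]     # same call shape as in the test above
print(dx.tolist())                     # [2.0, 4.0, 6.0], i.e. 2*x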