diff --git a/test/test_tensor.py b/test/test_tensor.py
index 1d6fc8175c..b35563ff1e 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -16,7 +16,7 @@ class TestTinygrad(unittest.TestCase):
     def test_tinygrad():
       x = Tensor(x_init)
       W = Tensor(W_init)
-      m = Tensor(m_init)
+      m = Tensor(m_init, requires_grad=False)
       out = x.dot(W).relu()
       out = out.logsoftmax()
       out = out.mul(m).add(m).sum()
@@ -64,6 +64,21 @@ class TestTinygrad(unittest.TestCase):
     for x,y in zip(test_tinygrad(), test_pytorch()):
       np.testing.assert_allclose(x, y, atol=1e-5)
 
+  def test_nograd(self):
+    x = Tensor(x_init, requires_grad=False)
+    m = Tensor(m_init, requires_grad=False)
+    W = Tensor(W_init, requires_grad=True)
+    tmp = x.mul(m)
+    mm = tmp.matmul(W)
+    out = mm.relu()
+    out = out.sum()
+    out.backward()
+    assert x.grad is None
+    assert m.grad is None
+    assert tmp.grad is None
+    assert mm.grad is not None
+    assert W.grad is not None
+
   def test_dropout(self):
     Tensor.training = True
     n, rate = 1_000_000, 0.1
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 93220dcc2a..06a9d79218 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -127,12 +127,14 @@ class Tensor:
 
     for t0 in reversed(self.deepwalk()):
       assert (t0.grad is not None)
+      if not any([x.requires_grad for x in t0._ctx.parents]):
+        continue
       with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True) as po:
         grads = t0._ctx.backward(t0._ctx, t0.grad.data)
       if len(t0._ctx.parents) == 1:
         grads = [grads]
       for t, g in zip(t0._ctx.parents, grads):
-        if g is not None:
+        if g is not None and t.requires_grad:
           assert g.shape == t.shape, \
             f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
           gt = Tensor(g, device=self.device, requires_grad=False)
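
As a rough sketch of the idea behind the tensor.py change (standalone pseudocode, not tinygrad's real Tensor/_ctx API; Node and backprop here are hypothetical stand-ins), the backward pass now skips any op whose parents all have requires_grad=False, and only opted-in parents accumulate a gradient:

class Node:
  def __init__(self, requires_grad=True, parents=(), backward_fn=None):
    self.requires_grad = requires_grad
    self.parents = list(parents)
    self.backward_fn = backward_fn  # maps the upstream grad to per-parent grads
    self.grad = None

def backprop(topo_order):
  # topo_order: interior nodes from the output back toward the leaves
  for t0 in topo_order:
    if t0.backward_fn is None:
      continue  # leaf node, nothing to propagate
    # Early exit: if no parent wants a gradient, the backward op is wasted work.
    if not any(p.requires_grad for p in t0.parents):
      continue
    grads = t0.backward_fn(t0.grad)
    for t, g in zip(t0.parents, grads):
      # Only tensors that opted in accumulate a gradient; the rest stay None,
      # which is what test_nograd asserts for x, m, and tmp.
      if g is not None and t.requires_grad:
        t.grad = g if t.grad is None else t.grad + g

# Usage mirroring test_nograd: x opts out, W opts in.
x = Node(requires_grad=False)
W = Node(requires_grad=True)
mm = Node(parents=[x, W], backward_fn=lambda g: (g, g))  # toy gradient rule
mm.grad = 1.0
backprop([mm])
assert x.grad is None and W.grad == 1.0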