don't run unneeded grads

George Hotz
2022-01-15 21:32:13 -08:00
parent ecc0903451
commit c0d1254003
2 changed files with 19 additions and 2 deletions

@@ -16,7 +16,7 @@ class TestTinygrad(unittest.TestCase):
     def test_tinygrad():
       x = Tensor(x_init)
       W = Tensor(W_init)
-      m = Tensor(m_init)
+      m = Tensor(m_init, requires_grad=False)
       out = x.dot(W).relu()
       out = out.logsoftmax()
       out = out.mul(m).add(m).sum()
@@ -64,6 +64,21 @@ class TestTinygrad(unittest.TestCase):
     for x,y in zip(test_tinygrad(), test_pytorch()):
       np.testing.assert_allclose(x, y, atol=1e-5)
 
+  def test_nograd(self):
+    x = Tensor(x_init, requires_grad=False)
+    m = Tensor(m_init, requires_grad=False)
+    W = Tensor(W_init, requires_grad=True)
+    tmp = x.mul(m)
+    mm = tmp.matmul(W)
+    out = mm.relu()
+    out = out.sum()
+    out.backward()
+    assert x.grad is None
+    assert m.grad is None
+    assert tmp.grad is None
+    assert mm.grad is not None
+    assert W.grad is not None
+
   def test_dropout(self):
     Tensor.training = True
     n, rate = 1_000_000, 0.1
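
For context, a minimal sketch of what "don't run unneeded grads" means in a reverse-mode autograd engine. This is not tinygrad's actual implementation; the Tensor class below, its mul/sum ops, and the _parents/_backward bookkeeping are invented for illustration. The idea the new test exercises: requires_grad propagates forward through every op, and backward() skips the gradient work for any node or input that doesn't need it.

import numpy as np

class Tensor:
  # Illustrative only: not tinygrad's real Tensor.
  def __init__(self, data, requires_grad=True):
    self.data = np.asarray(data, dtype=np.float32)
    self.grad = None
    self.requires_grad = requires_grad
    self._parents = ()
    self._backward = lambda: None

  def mul(self, other):
    # requires_grad propagates forward: the output needs a grad
    # iff any of its inputs does.
    out = Tensor(self.data * other.data,
                 requires_grad=self.requires_grad or other.requires_grad)
    out._parents = (self, other)
    def _backward():
      # Per-input skip: don't compute a grad nobody asked for.
      if self.requires_grad:
        g = out.grad * other.data
        self.grad = g if self.grad is None else self.grad + g
      if other.requires_grad:
        g = out.grad * self.data
        other.grad = g if other.grad is None else other.grad + g
    out._backward = _backward
    return out

  def sum(self):
    out = Tensor(self.data.sum(), requires_grad=self.requires_grad)
    out._parents = (self,)
    def _backward():
      if self.requires_grad:
        g = out.grad * np.ones_like(self.data)
        self.grad = g if self.grad is None else self.grad + g
    out._backward = _backward
    return out

  def backward(self):
    # Walk the graph that produced self in reverse topological order.
    topo, seen = [], set()
    def build(t):
      if id(t) not in seen:
        seen.add(id(t))
        for p in t._parents:
          build(p)
        topo.append(t)
    build(self)
    self.grad = np.ones_like(self.data)
    for t in reversed(topo):
      if t.requires_grad:  # don't run unneeded grads
        t._backward()

A usage example mirroring test_nograd: the x.mul(m) subgraph is built from tensors that don't require grad, so its output inherits requires_grad=False, its _backward never runs, and its .grad stays None, while W still gets a gradient.

x = Tensor([1., 2.], requires_grad=False)
m = Tensor([3., 4.], requires_grad=False)
W = Tensor([5., 6.], requires_grad=True)
tmp = x.mul(m)  # requires_grad=False: backward skips this subgraph
out = tmp.mul(W).sum()
out.backward()
assert tmp.grad is None and W.grad is not None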