Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00.
Commit message: don't run unneeded grads
This commit is contained in:
@@ -16,7 +16,7 @@ class TestTinygrad(unittest.TestCase):
|
||||
def test_tinygrad():
|
||||
x = Tensor(x_init)
|
||||
W = Tensor(W_init)
|
||||
m = Tensor(m_init)
|
||||
m = Tensor(m_init, requires_grad=False)
|
||||
out = x.dot(W).relu()
|
||||
out = out.logsoftmax()
|
||||
out = out.mul(m).add(m).sum()
|
||||
@@ -64,6 +64,21 @@ class TestTinygrad(unittest.TestCase):
|
||||
for x,y in zip(test_tinygrad(), test_pytorch()):
|
||||
np.testing.assert_allclose(x, y, atol=1e-5)
|
||||
|
||||
def test_nograd(self):
    """Check that backward() populates .grad only where gradients are needed.

    Tensors created with requires_grad=False (and intermediates computed
    purely from them) must end up with grad of None; tensors on the path
    to a requires_grad=True leaf must receive a gradient.
    """
    # Frozen inputs: nothing upstream of these requires a gradient.
    frozen_x = Tensor(x_init, requires_grad=False)
    frozen_m = Tensor(m_init, requires_grad=False)
    # Trainable leaf: gradients must flow back to this one.
    weight = Tensor(W_init, requires_grad=True)

    # no-grad * no-grad -> still no grad needed for this intermediate
    masked = frozen_x.mul(frozen_m)
    # matmul against a trainable tensor -> grad is needed from here on
    product = masked.matmul(weight)
    loss = product.relu().sum()
    loss.backward()

    # No gradient should have been computed on the frozen side.
    assert frozen_x.grad is None
    assert frozen_m.grad is None
    assert masked.grad is None
    # The trainable path must carry gradients.
    assert product.grad is not None
    assert weight.grad is not None
|
||||
|
||||
def test_dropout(self):
|
||||
Tensor.training = True
|
||||
n, rate = 1_000_000, 0.1
|
||||
|
||||
Reference in New Issue
Block a user