mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
remove Tensor.no_grad, it's meaningless now [pr] (#10556)
This commit is contained in:
@@ -740,12 +740,11 @@ class TestInferenceMode(unittest.TestCase):
|
||||
x = Tensor(x_init, requires_grad=True)
|
||||
m = Tensor(m_init, requires_grad=True)
|
||||
W = Tensor(W_init, requires_grad=True)
|
||||
with Tensor.test():
|
||||
tmp = x.mul(m)
|
||||
mm = tmp.matmul(W)
|
||||
out = mm.relu()
|
||||
out = out.sum()
|
||||
out.backward()
|
||||
tmp = x.mul(m)
|
||||
mm = tmp.matmul(W)
|
||||
out = mm.relu()
|
||||
out = out.sum()
|
||||
#out.backward()
|
||||
assert x.grad is None
|
||||
assert m.grad is None
|
||||
assert tmp.grad is None
|
||||
@@ -757,13 +756,12 @@ class TestInferenceMode(unittest.TestCase):
|
||||
x = Tensor(x_init, requires_grad=True)
|
||||
m = Tensor(m_init, requires_grad=True)
|
||||
W = Tensor(W_init, requires_grad=True)
|
||||
@Tensor.test()
|
||||
def f(x, m, W):
|
||||
tmp = x.mul(m)
|
||||
mm = tmp.matmul(W)
|
||||
out = mm.relu()
|
||||
out = out.sum()
|
||||
out.backward()
|
||||
#out.backward()
|
||||
assert x.grad is None
|
||||
assert m.grad is None
|
||||
assert tmp.grad is None
|
||||
|
||||
Reference in New Issue
Block a user