remove Tensor.no_grad, it's meaningless now [pr] (#10556)

This commit is contained in:
George Hotz
2025-05-28 22:20:02 -07:00
committed by GitHub
parent e4e7b5d7e1
commit b3b43a82c4
35 changed files with 17 additions and 80 deletions

View File

@@ -740,12 +740,11 @@ class TestInferenceMode(unittest.TestCase):
x = Tensor(x_init, requires_grad=True)
m = Tensor(m_init, requires_grad=True)
W = Tensor(W_init, requires_grad=True)
with Tensor.test():
tmp = x.mul(m)
mm = tmp.matmul(W)
out = mm.relu()
out = out.sum()
out.backward()
tmp = x.mul(m)
mm = tmp.matmul(W)
out = mm.relu()
out = out.sum()
#out.backward()
assert x.grad is None
assert m.grad is None
assert tmp.grad is None
@@ -757,13 +756,12 @@ class TestInferenceMode(unittest.TestCase):
x = Tensor(x_init, requires_grad=True)
m = Tensor(m_init, requires_grad=True)
W = Tensor(W_init, requires_grad=True)
@Tensor.test()
def f(x, m, W):
tmp = x.mul(m)
mm = tmp.matmul(W)
out = mm.relu()
out = out.sum()
out.backward()
#out.backward()
assert x.grad is None
assert m.grad is None
assert tmp.grad is None