Fix backward fn for < and == (#3037)

* fix no grad fn for < and ==

* remove 2 line breaks

* Remove deprecated autograd variable

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: Guy Leroy
Date: 2024-01-15 04:39:52 +00:00
Committed by: GitHub
Parent: db965a0c74
Commit: 0dba34b81c
3 changed files with 13 additions and 1 deletion
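
The change pins down the autograd contract for comparison ops: < and == produce boolean masks with no gradient of their own, so calling backward on a bare comparison should raise a RuntimeError, while a comparison used as a constant mask inside a differentiable expression should still backpropagate through the rest of the graph. A minimal PyTorch sketch of the same contract (illustrative only, not tinygrad code):

import torch

t = torch.randn(4, requires_grad=True)
mask = (t == 0)              # boolean result; requires_grad is False
assert not mask.requires_grad

# backward on a bare comparison fails: the result has no grad_fn
# (t == t).sum().backward()  would raise a RuntimeError

# a comparison used as a constant mask still backpropagates
(t * mask).sum().backward()  # d/dt (t * mask) = mask
print(t.grad)                # nonzero only where t == 0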

@@ -215,6 +215,11 @@ class TestOps(unittest.TestCase):
     tt1 = Tensor.ones(4, requires_grad=True)
     tt2 = Tensor.ones(4, requires_grad=True)
     self.assertRaises(RuntimeError, (tt1 == tt2).sum().backward)
+    tt = Tensor.randn(4, requires_grad=True)
+    (tt*(tt == 0)).sum().backward()
+    t = torch.tensor(tt.numpy(), requires_grad=True)
+    (t*(t == 0)).sum().backward()
+    np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), atol=5e-4, rtol=1e-5)

   def test_cmp_lt_backwards(self):
     t1 = torch.ones(4, requires_grad=True)
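
For the == test added above, treating the mask (tt == 0) as a constant in backward makes the gradient of (tt*(tt == 0)).sum() equal to the mask itself, which is zero almost everywhere for randn input; the torch comparison then checks that both frameworks agree. A small worked check of that expectation (hypothetical values, not from the diff):

import numpy as np
import torch

t = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)
(t * (t == 0)).sum().backward()
# the mask is [0., 1., 0.], and the gradient of t*mask w.r.t. t is the mask
np.testing.assert_allclose(t.grad.numpy(), [0.0, 1.0, 0.0])
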
@@ -223,6 +228,11 @@ class TestOps(unittest.TestCase):
     tt1 = Tensor.ones(4, requires_grad=True)
     tt2 = Tensor.ones(4, requires_grad=True)
     self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward)
+    tt = Tensor.randn(4, requires_grad=True)
+    (tt*(tt < 0)).sum().backward()
+    t = torch.tensor(tt.numpy(), requires_grad=True)
+    (t*(t < 0)).sum().backward()
+    np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), atol=5e-4, rtol=1e-5)

   #@unittest.skip("this is broken with contiguous")
   def test_trunc(self):
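
The < test mirrors the == one: tt*(tt < 0) equals min(tt, 0) elementwise, so the expected gradient is 1 on negative entries and 0 elsewhere, including at exactly 0, where the mask is False. A quick check of that expectation with made-up values:

import numpy as np
import torch

t = torch.tensor([-2.0, -0.5, 0.0, 3.0], requires_grad=True)
(t * (t < 0)).sum().backward()
# t*(t < 0) equals torch.minimum(t, torch.zeros_like(t)): slope 1 on negatives
np.testing.assert_allclose(t.grad.numpy(), [1.0, 1.0, 0.0, 0.0])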