mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Fix backward fn for < and == (#3037)
* fix no grad fn for < and == * remove 2 line breaks * Remove deprecated autograd variable --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -215,6 +215,11 @@ class TestOps(unittest.TestCase):
|
||||
tt1 = Tensor.ones(4, requires_grad=True)
|
||||
tt2 = Tensor.ones(4, requires_grad=True)
|
||||
self.assertRaises(RuntimeError, (tt1 == tt2).sum().backward)
|
||||
tt = Tensor.randn(4, requires_grad=True)
|
||||
(tt*(tt == 0)).sum().backward()
|
||||
t = torch.tensor(tt.numpy(), requires_grad=True)
|
||||
(t*(t == 0)).sum().backward()
|
||||
np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
def test_cmp_lt_backwards(self):
|
||||
t1 = torch.ones(4, requires_grad=True)
|
||||
@@ -223,6 +228,11 @@ class TestOps(unittest.TestCase):
|
||||
tt1 = Tensor.ones(4, requires_grad=True)
|
||||
tt2 = Tensor.ones(4, requires_grad=True)
|
||||
self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward)
|
||||
tt = Tensor.randn(4, requires_grad=True)
|
||||
(tt*(tt < 0)).sum().backward()
|
||||
t = torch.tensor(tt.numpy(), requires_grad=True)
|
||||
(t*(t < 0)).sum().backward()
|
||||
np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
#@unittest.skip("this is broken with contiguous")
|
||||
def test_trunc(self):
|
||||
|
||||
Reference in New Issue
Block a user