mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
clean up old tests (#12708)
This commit is contained in:
@@ -837,7 +837,6 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
|
||||
@unittest.skip("not accurate")
|
||||
def test_complex_backward(self):
|
||||
x = Tensor.rand(3, requires_grad=True).realize()
|
||||
y = Tensor.rand(3, requires_grad=True).realize()
|
||||
@@ -849,11 +848,12 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
|
||||
self.assertTrue(y.grad.uop.metadata[0].backward)
|
||||
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
||||
self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
|
||||
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
|
||||
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
|
||||
self.assertSetEqual(set(m.name for m in si.metadata), {"__mul__", "sigmoid", "relu"})
|
||||
bw = [m for m in si.metadata if m.backward]
|
||||
self.assertEqual(len(bw), 1)
|
||||
self.assertEqual(bw[0].name, "sigmoid")
|
||||
self.assertEqual(len(bw), 2)
|
||||
self.assertEqual(bw[0].name, "__mul__")
|
||||
self.assertEqual(bw[1].name, "sigmoid")
|
||||
|
||||
class TestIdxUpcast(unittest.TestCase):
|
||||
def _find_op(self, ast: UOp, op: Ops):
|
||||
|
||||
Reference in New Issue
Block a user