clean up old tests (#12708)

This commit is contained in:
chenyu
2025-10-15 17:53:17 -04:00
committed by GitHub
parent b8cf35fb77
commit c3278e5622
4 changed files with 6 additions and 58 deletions

View File

@@ -837,7 +837,6 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
@unittest.skip("not accurate")
def test_complex_backward(self):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()
@@ -849,11 +848,12 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
self.assertTrue(y.grad.uop.metadata[0].backward)
si = Tensor.schedule(out, x.grad, y.grad)[-1]
self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
self.assertSetEqual(set(m.name for m in si.metadata), {"__mul__", "sigmoid", "relu"})
bw = [m for m in si.metadata if m.backward]
self.assertEqual(len(bw), 1)
self.assertEqual(bw[0].name, "sigmoid")
self.assertEqual(len(bw), 2)
self.assertEqual(bw[0].name, "__mul__")
self.assertEqual(bw[1].name, "sigmoid")
class TestIdxUpcast(unittest.TestCase):
def _find_op(self, ast: UOp, op: Ops):