assert set equality in TestTensorMetadata [pr] (#7364)

This commit is contained in:
qazal
2024-10-29 13:29:29 +02:00
committed by GitHub
parent 0ebdb136e8
commit 7149eabb34

View File

@@ -709,41 +709,36 @@ class TestInferenceMode(unittest.TestCase):
f(x, m, W)
class TestTensorMetadata(unittest.TestCase):
def setUp(self) -> None: _METADATA.set(None)
def test_matmul(self):
_METADATA.set(None)
x = Tensor.rand(3, requires_grad=True)
W = Tensor.rand(3, 3, requires_grad=True)
out = x.matmul(W)
self.assertEqual(out.lazydata.metadata.name, "matmul")
s = create_schedule([out.lazydata])
self.assertEqual(len(s[-1].metadata), 1)
self.assertEqual(s[-1].metadata[0].name, "matmul")
si = create_schedule([out.lazydata])[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "matmul")
def test_relu(self):
_METADATA.set(None)
x = Tensor.rand(3, requires_grad=True)
out = x.relu()
self.assertEqual(out.lazydata.metadata.name, "relu")
s = create_schedule([out.lazydata])
self.assertEqual(len(s[-1].metadata), 1)
self.assertEqual(s[-1].metadata[0].name, "relu")
si = create_schedule([out.lazydata])[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "relu")
def test_complex(self):
_METADATA.set(None)
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = x.relu() * y.sigmoid()
self.assertEqual(out.lazydata.metadata.name, "__mul__")
self.assertEqual(out.lazydata.srcs[0].metadata.name, "relu")
self.assertEqual(out.lazydata.srcs[1].metadata.name, "sigmoid")
s = create_schedule([out.lazydata])
self.assertEqual(len(s[-1].metadata), 3)
self.assertEqual(s[-1].metadata[0].name, "relu")
self.assertEqual(s[-1].metadata[1].name, "sigmoid")
self.assertEqual(s[-1].metadata[2].name, "__mul__")
si = create_schedule([out.lazydata])[-1]
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
def test_complex_backward(self):
_METADATA.set(None)
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = (x.relu() * y.sigmoid()).sum()
@@ -753,12 +748,12 @@ class TestTensorMetadata(unittest.TestCase):
self.assertTrue(x.grad.lazydata.metadata.backward)
self.assertEqual(y.grad.lazydata.metadata.name, "sigmoid")
self.assertTrue(y.grad.lazydata.metadata.backward)
s = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata])
self.assertEqual(len(s[-1].metadata), 3)
self.assertEqual(s[-1].metadata[0].name, "sigmoid")
self.assertEqual(s[-1].metadata[1].name, "sigmoid")
self.assertTrue(s[-1].metadata[1].backward)
self.assertEqual(s[-1].metadata[2].name, "relu")
si = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata])[-1]
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"sigmoid", "sigmoid", "relu"})
bw = [m for m in si.metadata if m.backward]
self.assertEqual(len(bw), 1)
self.assertEqual(bw[0].name, "sigmoid")
if __name__ == '__main__':
unittest.main()