revert metadata with graph_rewrite (#7353) (#7362)

This reverts commit 540e4179e7.
This commit is contained in:
qazal
2024-10-29 13:16:31 +02:00
committed by GitHub
parent f2044cfb22
commit 0ebdb136e8
2 changed files with 14 additions and 18 deletions

View File

@@ -738,9 +738,9 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(out.lazydata.srcs[1].metadata.name, "sigmoid")
s = create_schedule([out.lazydata])
self.assertEqual(len(s[-1].metadata), 3)
self.assertEqual(s[-1].metadata[0].name, "__mul__")
self.assertEqual(s[-1].metadata[1].name, "relu")
self.assertEqual(s[-1].metadata[2].name, "sigmoid")
self.assertEqual(s[-1].metadata[0].name, "relu")
self.assertEqual(s[-1].metadata[1].name, "sigmoid")
self.assertEqual(s[-1].metadata[2].name, "__mul__")
def test_complex_backward(self):
_METADATA.set(None)
@@ -757,7 +757,7 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(len(s[-1].metadata), 3)
self.assertEqual(s[-1].metadata[0].name, "sigmoid")
self.assertEqual(s[-1].metadata[1].name, "sigmoid")
self.assertTrue(s[-1].metadata[0].backward)
self.assertTrue(s[-1].metadata[1].backward)
self.assertEqual(s[-1].metadata[2].name, "relu")
if __name__ == '__main__':