uops can have multiple metadata (#10479)

* uops can have multiple metadata

* fixups
This commit is contained in:
George Hotz
2025-05-22 21:35:02 -07:00
committed by GitHub
parent 283586bb96
commit 1e4d63e06e
5 changed files with 26 additions and 26 deletions

View File

@@ -775,9 +775,10 @@ class TestTensorMetadata(unittest.TestCase):
def setUp(self) -> None: _METADATA.set(None)
# NOOPs are not included in kernel metadata
@unittest.skip("why would this be true?")
def test_exclude_noop_metadata(self):
a = Tensor.rand(4, 4)*1
self.assertEqual(a.lazydata.metadata.name, "__mul__")
self.assertEqual(a.lazydata.metadata[0].name, "__mul__")
k = a.schedule()[-1]
self.assertEqual([m.name for m in k.metadata], ["rand"])
@@ -794,7 +795,7 @@ class TestTensorMetadata(unittest.TestCase):
x = Tensor.rand(3, requires_grad=True)
W = Tensor.rand(3, 3, requires_grad=True)
out = x.matmul(W)
self.assertEqual(out.lazydata.metadata.name, "matmul")
self.assertEqual(out.lazydata.metadata[0].name, "matmul")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "matmul")
@@ -802,7 +803,7 @@ class TestTensorMetadata(unittest.TestCase):
def test_relu(self):
x = Tensor.rand(3, requires_grad=True)
out = x.relu()
self.assertEqual(out.lazydata.metadata.name, "relu")
self.assertEqual(out.lazydata.metadata[0].name, "relu")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "relu")
@@ -811,9 +812,9 @@ class TestTensorMetadata(unittest.TestCase):
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = x.relu() * y.sigmoid()
self.assertEqual(out.lazydata.metadata.name, "__mul__")
self.assertEqual(out.lazydata.src[0].metadata.name, "relu")
self.assertEqual(out.lazydata.src[1].metadata.name, "sigmoid")
self.assertEqual(out.lazydata.metadata[0].name, "__mul__")
self.assertEqual(out.lazydata.src[0].metadata[0].name, "relu")
self.assertEqual(out.lazydata.src[1].metadata[0].name, "sigmoid")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
@@ -822,17 +823,17 @@ class TestTensorMetadata(unittest.TestCase):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()
out = (x.relu() * y.sigmoid()).sum()
self.assertEqual(out.lazydata.metadata.name, "sum")
self.assertEqual(out.lazydata.metadata[0].name, "sum")
out.backward()
self.assertEqual(x.grad.lazydata.metadata.name, "relu")
self.assertTrue(x.grad.lazydata.metadata.backward)
self.assertEqual(y.grad.lazydata.metadata.name, "sigmoid")
self.assertTrue(y.grad.lazydata.metadata.backward)
self.assertEqual(x.grad.lazydata.metadata[0].name, "relu")
self.assertTrue(x.grad.lazydata.metadata[0].backward)
self.assertEqual(y.grad.lazydata.metadata[0].name, "sigmoid")
self.assertTrue(y.grad.lazydata.metadata[0].backward)
si = Tensor.schedule(out, x.grad, y.grad)[-1]
self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
self.assertEqual(set(m.name for m in si.metadata), {"sigmoid", "sigmoid", "relu"})
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
bw = [m for m in si.metadata if m.backward]
self.assertEqual(len(bw), 1)
self.assertEqual(len(bw), 2)
self.assertEqual(bw[0].name, "sigmoid")
class TestIdxUpcast(unittest.TestCase):