simpler tensor metadata mapping + tests [pr] (#9203)

* simpler tensor metadata mapping + tests [pr]

* remove kernel metadata

* don't map nones
This commit is contained in:
qazal
2025-02-22 21:18:46 +02:00
committed by GitHub
parent b711c6343a
commit 4578c3e8fd
2 changed files with 18 additions and 3 deletions

View File

@@ -746,6 +746,22 @@ class TestInferenceMode(unittest.TestCase):
class TestTensorMetadata(unittest.TestCase):
def setUp(self) -> None: _METADATA.set(None)
# NOOPs are not included in kernel metadata
def test_exclude_noop_metadata(self):
a = Tensor.rand(4, 4)*1
self.assertEqual(a.lazydata.metadata.name, "__mul__")
k = a.schedule()[-1]
self.assertEqual([m.name for m in k.metadata], ["rand"])
# we exclude const from kernel metadata because tensor methods can share the same CONST UOp
def test_exclude_const_metadata(self):
a = Tensor.arange(4)
b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
sched = Tensor.schedule(a, b)
self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])
def test_matmul(self):
x = Tensor.rand(3, requires_grad=True)
W = Tensor.rand(3, 3, requires_grad=True)