mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
simpler tensor metadata mapping + tests [pr] (#9203)
* simpler tensor metadata mapping + tests [pr] * remove kernel metadata * don't map nones
This commit is contained in:
@@ -746,6 +746,22 @@ class TestInferenceMode(unittest.TestCase):
|
||||
|
||||
class TestTensorMetadata(unittest.TestCase):
|
||||
def setUp(self) -> None: _METADATA.set(None)
|
||||
|
||||
# NOOPs are not included in kernel metadata
|
||||
def test_exclude_noop_metadata(self):
|
||||
a = Tensor.rand(4, 4)*1
|
||||
self.assertEqual(a.lazydata.metadata.name, "__mul__")
|
||||
k = a.schedule()[-1]
|
||||
self.assertEqual([m.name for m in k.metadata], ["rand"])
|
||||
|
||||
# we exclude const from kernel metadata because tensor methods can share the same CONST UOp
|
||||
def test_exclude_const_metadata(self):
|
||||
a = Tensor.arange(4)
|
||||
b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
|
||||
sched = Tensor.schedule(a, b)
|
||||
self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
|
||||
self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])
|
||||
|
||||
def test_matmul(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
W = Tensor.rand(3, 3, requires_grad=True)
|
||||
|
||||
Reference in New Issue
Block a user