diff --git a/test/test_tensor.py b/test/test_tensor.py index 7a7d21dabf..d07074688a 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -746,6 +746,22 @@ class TestInferenceMode(unittest.TestCase): class TestTensorMetadata(unittest.TestCase): def setUp(self) -> None: _METADATA.set(None) + + # NOOPs are not included in kernel metadata + def test_exclude_noop_metadata(self): + a = Tensor.rand(4, 4)*1 + self.assertEqual(a.lazydata.metadata.name, "__mul__") + k = a.schedule()[-1] + self.assertEqual([m.name for m in k.metadata], ["rand"]) + + # we exclude const from kernel metadata because tensor methods can share the same CONST UOp + def test_exclude_const_metadata(self): + a = Tensor.arange(4) + b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous() + sched = Tensor.schedule(a, b) + self.assertEqual([m.name for m in sched[0].metadata], ["arange"]) + self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"]) + def test_matmul(self): x = Tensor.rand(3, requires_grad=True) W = Tensor.rand(3, 3, requires_grad=True) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index cf6c74dc8a..8d67b885ff 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -380,7 +380,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va # map tensors to new uops becomes_map: dict[UOp, UOp] = {} rev_tensor_map: dict[UOp, list[UOp]] = {} - ops_metadata: dict[UOp, Metadata] = {} for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k) if k is v: continue @@ -391,9 +390,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va if k is not mop: becomes_map[k] = mop else: becomes_map[k] = v elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v - # if we're not realizing this tensor, map its metadata to the simplified uop - elif isinstance(k.metadata, Metadata): ops_metadata[v] = k.metadata + # map tensor metadata to simplified ops + ops_metadata = {v:k.metadata for k,v in tensor_map.items() if k.base.op not in {Ops.CONST, Ops.DEVICE} and isinstance(k.metadata, Metadata)} # create kernels kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True) sched_sink = kernel_map[sink]