simpler tensor metadata mapping + tests [pr] (#9203)

* simpler tensor metadata mapping + tests [pr]

* remove kernel metadata

* don't map nones
This commit is contained in:
qazal
2025-02-22 21:18:46 +02:00
committed by GitHub
parent b711c6343a
commit 4578c3e8fd
2 changed files with 18 additions and 3 deletions

View File

@@ -746,6 +746,22 @@ class TestInferenceMode(unittest.TestCase):
class TestTensorMetadata(unittest.TestCase):
def setUp(self) -> None:
  # unittest calls setUp before every test method; set _METADATA to None each time
  _METADATA.set(None)
# NOOPs are not included in kernel metadata
def test_exclude_noop_metadata(self):
  # the tensor produced by `rand(...)*1` reports "__mul__" as its own metadata name
  scaled = Tensor.rand(4, 4)*1
  self.assertEqual(scaled.lazydata.metadata.name, "__mul__")
  # the last scheduled item is expected to list only "rand" in its metadata
  last_item = scaled.schedule()[-1]
  kernel_names = [meta.name for meta in last_item.metadata]
  self.assertEqual(kernel_names, ["rand"])
# we exclude const from kernel metadata because tensor methods can share the same CONST UOp
def test_exclude_const_metadata(self):
  ramp = Tensor.arange(4)
  filled = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
  items = Tensor.schedule(ramp, filled)
  # item 0 should carry only "arange" and item 1 only "contiguous"; "full" appears in neither
  expected_names = (["arange"], ["contiguous"])
  for idx, names in enumerate(expected_names):
    self.assertEqual([meta.name for meta in items[idx].metadata], names)
def test_matmul(self):
x = Tensor.rand(3, requires_grad=True)
W = Tensor.rand(3, 3, requires_grad=True)

View File

@@ -380,7 +380,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
# map tensors to new uops
becomes_map: dict[UOp, UOp] = {}
rev_tensor_map: dict[UOp, list[UOp]] = {}
ops_metadata: dict[UOp, Metadata] = {}
for k,v in tensor_map.items():
rev_tensor_map.setdefault(v, []).append(k)
if k is v: continue
@@ -391,9 +390,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
if k is not mop: becomes_map[k] = mop
else: becomes_map[k] = v
elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
# if we're not realizing this tensor, map its metadata to the simplified uop
elif isinstance(k.metadata, Metadata): ops_metadata[v] = k.metadata
# map tensor metadata to simplified ops
ops_metadata = {v:k.metadata for k,v in tensor_map.items() if k.base.op not in {Ops.CONST, Ops.DEVICE} and isinstance(k.metadata, Metadata)}
# create kernels
kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
sched_sink = kernel_map[sink]