mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
simpler tensor metadata mapping + tests [pr] (#9203)
* simpler tensor metadata mapping + tests [pr] * remove kernel metadata * don't map nones
This commit is contained in:
@@ -746,6 +746,22 @@ class TestInferenceMode(unittest.TestCase):
|
||||
|
||||
class TestTensorMetadata(unittest.TestCase):
|
||||
def setUp(self) -> None: _METADATA.set(None)
|
||||
|
||||
# NOOPs are not included in kernel metadata
|
||||
def test_exclude_noop_metadata(self):
|
||||
a = Tensor.rand(4, 4)*1
|
||||
self.assertEqual(a.lazydata.metadata.name, "__mul__")
|
||||
k = a.schedule()[-1]
|
||||
self.assertEqual([m.name for m in k.metadata], ["rand"])
|
||||
|
||||
# we exclude const from kernel metadata because tensor methods can share the same CONST UOp
|
||||
def test_exclude_const_metadata(self):
|
||||
a = Tensor.arange(4)
|
||||
b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
|
||||
sched = Tensor.schedule(a, b)
|
||||
self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
|
||||
self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])
|
||||
|
||||
def test_matmul(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
W = Tensor.rand(3, 3, requires_grad=True)
|
||||
|
||||
@@ -380,7 +380,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
# map tensors to new uops
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
rev_tensor_map: dict[UOp, list[UOp]] = {}
|
||||
ops_metadata: dict[UOp, Metadata] = {}
|
||||
for k,v in tensor_map.items():
|
||||
rev_tensor_map.setdefault(v, []).append(k)
|
||||
if k is v: continue
|
||||
@@ -391,9 +390,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
if k is not mop: becomes_map[k] = mop
|
||||
else: becomes_map[k] = v
|
||||
elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
|
||||
# if we're not realizing this tensor, map its metadata to the simplified uop
|
||||
elif isinstance(k.metadata, Metadata): ops_metadata[v] = k.metadata
|
||||
|
||||
# map tensor metadata to simplified ops
|
||||
ops_metadata = {v:k.metadata for k,v in tensor_map.items() if k.base.op not in {Ops.CONST, Ops.DEVICE} and isinstance(k.metadata, Metadata)}
|
||||
# create kernels
|
||||
kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
|
||||
sched_sink = kernel_map[sink]
|
||||
|
||||
Reference in New Issue
Block a user