add more rangeify pm tests

This commit is contained in:
George Hotz
2025-10-07 17:20:39 +08:00
parent 403fdfcfd4
commit 9de19ef89f

View File

@@ -302,18 +302,53 @@ class TestOuterworld(unittest.TestCase):
from tinygrad.schedule.rangeify import pm_rangeify, RangeifyContext
class TestRangeifyPM(unittest.TestCase):
@unittest.expectedFailure
def test_reshape_match(self):
def proc(a:Tensor):
sink = a.uop.sink()
def setUp(self): self.base = Tensor.empty(10*10).reshape(10, 10).contiguous()
def assert_same(self, a, b):
def run_pm_rangeify(t:Tensor):
sink = t.uop.sink()
pm_realize = PatternMatcher([(UPat(Ops.CONTIGUOUS, name="x"), lambda x: x.replace(op=Ops.REALIZE))])
sink = graph_rewrite(sink, pm_realize)
return graph_rewrite(sink, pm_rangeify, ctx=RangeifyContext())
a = Tensor.empty(10*10).reshape(10, 10).contiguous().pad(((0,0),(0,1))).contiguous()
b = Tensor.empty(10*10).reshape(10, 10).contiguous().reshape(100).reshape(10, 10).pad(((0,0),(0,1))).contiguous()
sink1 = proc(a)
sink2 = proc(b)
self.assertIs(sink1, sink2)
self.assertIs(run_pm_rangeify(a.contiguous()), run_pm_rangeify(b.contiguous()))
def test_nothing_match(self):
a = self.base.pad(((0,0),(0,1)))
b = self.base.pad(((0,0),(0,1)))
self.assert_same(a, b)
def test_reshape_match(self):
a = self.base
b = self.base.reshape(100).reshape(10, 10)
self.assert_same(a, b)
def test_permute_reshape_match(self):
a = self.base
b = self.base.permute(1,0).reshape(100).reshape(10, 10).permute(1,0)
self.assert_same(a, b)
def test_padded_permute_match(self):
a = self.base.pad(((0,0),(0,1)))
b = self.base.permute(1,0).pad(((0,1),(0,0))).permute(1,0)
self.assert_same(a, b)
@unittest.expectedFailure
def test_padded_reshape_match(self):
a = self.base.pad(((0,0),(0,1)))
b = self.base.reshape(100).reshape(10, 10).pad(((0,0),(0,1)))
self.assert_same(a, b)
@unittest.expectedFailure
def test_padded_permute_reshape_match(self):
a = self.base.pad(((0,0),(0,1)))
b = self.base.permute(1,0).reshape(100).reshape(10, 10).pad(((0,1),(0,0))).permute(1,0)
self.assert_same(a, b)
# why is this failing?
@unittest.expectedFailure
def test_cross_pad_match(self):
a = self.base.pad(((0,0),(0,1))).pad(((0,1),(0,0)))
b = self.base.pad(((0,1),(0,0))).pad(((0,0),(0,1)))
self.assert_same(a, b)
class TestRangeifyEdgeCase(unittest.TestCase):
def test_matmul_relu_cat(self):