diff --git a/test/test_rangeify.py b/test/test_rangeify.py index 531c25f5eb..c05674956c 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -1,7 +1,7 @@ import unittest from tinygrad import Tensor, nn from tinygrad.helpers import RANGEIFY, Context, GlobalCounters -from tinygrad.uop.ops import UOp +from tinygrad.uop.ops import UOp, graph_rewrite, PatternMatcher, UPat, Ops @unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY") class TestRangeifyAssign(unittest.TestCase): @@ -300,6 +300,21 @@ class TestOuterworld(unittest.TestCase): o.contiguous(i).realize() self.assertTrue((t==o).all().item()) +from tinygrad.schedule.rangeify import pm_rangeify, RangeifyContext +class TestRangeifyPM(unittest.TestCase): + @unittest.expectedFailure + def test_reshape_match(self): + def proc(a:Tensor): + sink = a.uop.sink() + pm_realize = PatternMatcher([(UPat(Ops.CONTIGUOUS, name="x"), lambda x: x.replace(op=Ops.REALIZE))]) + sink = graph_rewrite(sink, pm_realize) + return graph_rewrite(sink, pm_rangeify, ctx=RangeifyContext()) + a = Tensor.empty(10*10).reshape(10, 10).contiguous().pad(((0,0),(0,1))).contiguous() + b = Tensor.empty(10*10).reshape(10, 10).contiguous().reshape(100).reshape(10, 10).pad(((0,0),(0,1))).contiguous() + sink1 = proc(a) + sink2 = proc(b) + self.assertIs(sink1, sink2) + class TestRangeifyEdgeCase(unittest.TestCase): def test_matmul_relu_cat(self): a = Tensor.ones(100, 512).contiguous().realize()