test_reshape_match should match (#12479)

This commit is contained in:
George Hotz
2025-10-07 16:07:21 +08:00
committed by GitHub
parent fe774a4319
commit 75ce11593c

View File

@@ -1,7 +1,7 @@
import unittest
from tinygrad import Tensor, nn
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
from tinygrad.uop.ops import UOp
from tinygrad.uop.ops import UOp, graph_rewrite, PatternMatcher, UPat, Ops
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
class TestRangeifyAssign(unittest.TestCase):
@@ -300,6 +300,21 @@ class TestOuterworld(unittest.TestCase):
o.contiguous(i).realize()
self.assertTrue((t==o).all().item())
from tinygrad.schedule.rangeify import pm_rangeify, RangeifyContext
class TestRangeifyPM(unittest.TestCase):
@unittest.expectedFailure
def test_reshape_match(self):
def proc(a:Tensor):
sink = a.uop.sink()
pm_realize = PatternMatcher([(UPat(Ops.CONTIGUOUS, name="x"), lambda x: x.replace(op=Ops.REALIZE))])
sink = graph_rewrite(sink, pm_realize)
return graph_rewrite(sink, pm_rangeify, ctx=RangeifyContext())
a = Tensor.empty(10*10).reshape(10, 10).contiguous().pad(((0,0),(0,1))).contiguous()
b = Tensor.empty(10*10).reshape(10, 10).contiguous().reshape(100).reshape(10, 10).pad(((0,0),(0,1))).contiguous()
sink1 = proc(a)
sink2 = proc(b)
self.assertIs(sink1, sink2)
class TestRangeifyEdgeCase(unittest.TestCase):
def test_matmul_relu_cat(self):
a = Tensor.ones(100, 512).contiguous().realize()