mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
more test_pattern_matcher fixups (#5714)
This commit is contained in:
@@ -3,7 +3,7 @@ from test.helpers import TestUOps
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
|
||||
from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
|
||||
from tinygrad.codegen.uopgraph import UOpGraph, constant_folder
|
||||
from tinygrad.codegen.uopgraph import constant_folder
|
||||
|
||||
class TestPatternMatcher(TestUOps):
|
||||
def test_simple_match(self):
|
||||
@@ -49,7 +49,7 @@ class TestPatternMatcher(TestUOps):
|
||||
|
||||
def test_filter_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, (UPat(UOps.CONST, name="c"), UPat(UOps.CONST, 2)), name="x"),
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c"), UPat(UOps.CONST, 2)], name="x"),
|
||||
lambda x,c: x if c.arg in {1, -1} else None)
|
||||
])
|
||||
y1 = UOp(UOps.CONST, dtypes.int, arg=1)
|
||||
@@ -58,14 +58,13 @@ class TestPatternMatcher(TestUOps):
|
||||
c1 = UOp(UOps.ALU, dtypes.int, (y1, y2), BinaryOps.MUL)
|
||||
c2 = UOp(UOps.ALU, dtypes.int, (y2, y2), BinaryOps.MUL)
|
||||
c3 = UOp(UOps.ALU, dtypes.int, (y3, y2), BinaryOps.MUL)
|
||||
# c4 = UOp(UOps.ALU, dtypes.int, (y2, y1), BinaryOps.MUL)
|
||||
# c5 = UOp(UOps.ALU, dtypes.int, (y2, y3), BinaryOps.MUL)
|
||||
c4 = UOp(UOps.ALU, dtypes.int, (y2, y1), BinaryOps.MUL)
|
||||
c5 = UOp(UOps.ALU, dtypes.int, (y2, y3), BinaryOps.MUL)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
self.assertEqual(matcher.rewrite(c3), c3)
|
||||
# TODO: match these
|
||||
# self.assertEqual(matcher.rewrite(c4), c4)
|
||||
# self.assertEqual(matcher.rewrite(c5), c5)
|
||||
self.assertEqual(matcher.rewrite(c4), c4)
|
||||
self.assertEqual(matcher.rewrite(c5), c5)
|
||||
|
||||
def test_dup_name(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
|
||||
@@ -135,10 +134,10 @@ class TestPatternMatcher(TestUOps):
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
|
||||
#c4 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.NEG)
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.NEG)
|
||||
c5 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
c6 = UOp(UOps.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
|
||||
#self.assertEqual(matcher.rewrite(c4), c4)
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
self.assertEqual(matcher.rewrite(c5), None)
|
||||
self.assertEqual(matcher.rewrite(c6), c6)
|
||||
|
||||
@@ -153,37 +152,6 @@ class TestPatternMatcher(TestUOps):
|
||||
self.assertIsNotNone(matcher.rewrite(u1))
|
||||
self.assertIsNotNone(matcher.rewrite(u2))
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_rewrite_graph_folds(self):
|
||||
uops = UOpGraph()
|
||||
UOp(UOps.CONST, dtypes.float, arg=2.0, simplify=False)
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float),
|
||||
lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x, x), BinaryOps.ADD),)))])
|
||||
matcher.rewrite_graph(uops)
|
||||
# TODO: fix this. it's 2 now
|
||||
# self.assertEqual(len(uops.uops), 1)
|
||||
self.assertEqual(len(uops.uops), 2)
|
||||
self.assert_equiv_uops(UOp(UOps.CONST, dtypes.int, arg=4), uops.uops[-1])
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_rewrite_graph_adds(self):
|
||||
uops = UOpGraph()
|
||||
UOp(UOps.CONST, dtypes.int, arg=2, simplify=False)
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.int),
|
||||
lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))])
|
||||
matcher.rewrite_graph(uops)
|
||||
uops.remove_childless(set(x for x in uops if x.op in {UOps.STORE}))
|
||||
|
||||
self.assertEqual(len(uops.uops), 3)
|
||||
|
||||
e1 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
e2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
|
||||
e3 = UOp(UOps.STORE, dtypes.int, (e2,e1))
|
||||
|
||||
self.assert_equiv_uops(e1, uops.uops[0])
|
||||
self.assert_equiv_uops(e2, uops.uops[1])
|
||||
self.assert_equiv_uops(e3, uops.uops[2])
|
||||
|
||||
def _assert_eq_upat(self, a:UPat, b:UPat):
|
||||
assert (sorted(map(str,a.op)) if a.op else [] == (sorted(map(str,b.op)) if b.op else []))
|
||||
assert (sorted(a.dtype) if a.dtype else [] == (sorted(b.dtype) if b.dtype else []))
|
||||
@@ -196,6 +164,8 @@ class TestPatternMatcher(TestUOps):
|
||||
|
||||
def test_upat_str(self):
|
||||
dtypes._float2 = dtypes.float.vec(2)
|
||||
dtypes._float4 = dtypes.float.vec(4)
|
||||
dtypes._float8 = dtypes.float.vec(8)
|
||||
upat = UPat(UOps.CONST, name="x", dtype=dtypes.float)
|
||||
assert str(upat) == str(eval(str(upat)))
|
||||
evpat:UPat = eval(repr(UPat(src = [UPat(name='a'), UPat(name='b')])))
|
||||
@@ -203,9 +173,9 @@ class TestPatternMatcher(TestUOps):
|
||||
for i in range(20): upat = UPat(UOps.ALU, name="x", src=[upat, upat], arg=BinaryOps.ADD)
|
||||
assert len(str(upat)) < 10_000
|
||||
assert str(eval(str(upat))) == str(upat)
|
||||
for rule in constant_folder.pdict.values():
|
||||
pat = rule[0][0]
|
||||
self._assert_eq_upat(pat, eval(str(pat)))
|
||||
for rules in constant_folder.pdict.values():
|
||||
for pat,_ in rules:
|
||||
self._assert_eq_upat(pat, eval(str(pat)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user