From 05e02ddfb3a22b44cfced48c1153ce05c8c48c41 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 25 Jul 2024 13:48:52 -0400 Subject: [PATCH] fixup test_pattern_matcher (#5712) --- test/test_pattern_matcher.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/test/test_pattern_matcher.py b/test/test_pattern_matcher.py index 824148c107..1b6c6b1d82 100644 --- a/test/test_pattern_matcher.py +++ b/test/test_pattern_matcher.py @@ -2,7 +2,7 @@ import unittest, itertools from test.helpers import TestUOps from tinygrad.dtype import dtypes from tinygrad.ops import BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401 -from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat, _match +from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat from tinygrad.codegen.uopgraph import UOpGraph, constant_folder class TestPatternMatcher(TestUOps): @@ -47,18 +47,25 @@ class TestPatternMatcher(TestUOps): self.assertEqual(matcher.rewrite(c4), None) self.assertEqual(matcher.rewrite(c5), None) - @unittest.skip("this is not supported any more") - def test_arg_set(self): - matcher = PatternMatcher([(UPat(UOps.ALU, BinaryOps.MUL, (UPat(UOps.CONST, {-1, 1}), UPat(UOps.CONST, 2)), name="x"), lambda x: x)]) + def test_filter_arg(self): + matcher = PatternMatcher([ + (UPat(UOps.ALU, BinaryOps.MUL, (UPat(UOps.CONST, name="c"), UPat(UOps.CONST, 2)), name="x"), + lambda x,c: x if c.arg in {1, -1} else None) + ]) y1 = UOp(UOps.CONST, dtypes.int, arg=1) y2 = UOp(UOps.CONST, dtypes.int, arg=2) y3 = UOp(UOps.CONST, dtypes.int, arg=-1) c1 = UOp(UOps.ALU, dtypes.int, (y1, y2), BinaryOps.MUL) c2 = UOp(UOps.ALU, dtypes.int, (y2, y2), BinaryOps.MUL) c3 = UOp(UOps.ALU, dtypes.int, (y3, y2), BinaryOps.MUL) + # c4 = UOp(UOps.ALU, dtypes.int, (y2, y1), BinaryOps.MUL) + # c5 = UOp(UOps.ALU, dtypes.int, (y2, y3), BinaryOps.MUL) self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), None) self.assertEqual(matcher.rewrite(c3), c3) + # TODO: match these + # self.assertEqual(matcher.rewrite(c4), c4) + # self.assertEqual(matcher.rewrite(c5), c5) def test_dup_name(self): matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)]) @@ -77,7 +84,7 @@ class TestPatternMatcher(TestUOps): self.assertEqual(matcher.rewrite(c2), None) def test_dtype_set(self): - matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=set([dtypes.float32, dtypes.float64])), lambda x: x)]) + matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0) c3 = UOp(UOps.CONST, dtypes.float16, arg=1.0) @@ -87,7 +94,7 @@ class TestPatternMatcher(TestUOps): self.assertEqual(matcher.rewrite(c3), None) self.assertEqual(matcher.rewrite(c4), None) - def test_vin_one(self): + def test_src_one(self): matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) @@ -101,7 +108,7 @@ class TestPatternMatcher(TestUOps): self.assertEqual(matcher.rewrite(c4), c4) self.assertEqual(matcher.rewrite(c5), None) - def test_vin_permutations(self): + def test_src_permutations(self): matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) @@ -114,7 +121,7 @@ class TestPatternMatcher(TestUOps): self.assertEqual(matcher.rewrite(c5), c5) self.assertEqual(matcher.rewrite(c6), None) - def test_vin_repeat(self): + def test_src_repeat(self): matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=UPat(UOps.CONST)), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) @@ -140,10 +147,11 @@ class TestPatternMatcher(TestUOps): c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) u1 = (c1 + c2) + c1 u2 = (c2 + c1) + c1 - pat = UPat(UOps.ALU, src = (UPat(UOps.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b'))) - # TODO: why is this calling a private function? - assert _match(u1, pat, {}) - assert _match(u2, pat, {}) + matcher = PatternMatcher([ + (UPat(UOps.ALU, src=[UPat(UOps.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b) + ]) + self.assertIsNotNone(matcher.rewrite(u1)) + self.assertIsNotNone(matcher.rewrite(u2)) @unittest.skip("no longer supported") def test_rewrite_graph_folds(self):