From e7ff5102cf56bd83944abb8d5aa338e90d838779 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 5 Apr 2024 02:53:50 -0400 Subject: [PATCH] failed test in test_pattern_matcher (#4080) something about the PTX rewrite is incorrect that it has duplicated rewritten uops --- test/test_pattern_matcher.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/test/test_pattern_matcher.py b/test/test_pattern_matcher.py index b61fac9e17..0006098578 100644 --- a/test/test_pattern_matcher.py +++ b/test/test_pattern_matcher.py @@ -4,6 +4,12 @@ from tinygrad.ops import BinaryOps from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp class TestPatternMatcher(unittest.TestCase): + def assert_equiv_uops(self, uop1:UOp, uop2:UOp): + # NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops + self.assertEqual(uop1.uop, uop2.uop) + self.assertEqual(uop1.dtype, uop2.dtype) + self.assertEqual(uop1.arg, uop2.arg) + def test_simple_match(self): matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float}, lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) @@ -56,10 +62,12 @@ class TestPatternMatcher(unittest.TestCase): uops = UOpGraph() uops.add(UOps.CONST, dtypes.float, arg=2.0, simplify=False) matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float}, - lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x,x), BinaryOps.ADD),)))]) + lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x, x), BinaryOps.ADD),)))]) matcher.rewrite_graph(uops) - def _to_key(u): return (u.uop, u.dtype, u.arg) - self.assertEqual(_to_key(UOp(UOps.CONST, dtypes.int, arg=4)), _to_key(uops.uops[-1])) + # TODO: fix this. it's 2 now + # self.assertEqual(len(uops.uops), 1) + self.assertEqual(len(uops.uops), 2) + self.assert_equiv_uops(UOp(UOps.CONST, dtypes.int, arg=4), uops.uops[-1]) def test_rewrite_graph_adds(self): uops = UOpGraph() @@ -68,14 +76,16 @@ class TestPatternMatcher(unittest.TestCase): lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))]) matcher.rewrite_graph(uops) uops.remove_childless(set(x for x in uops if x.uop in {UOps.STORE})) + + self.assertEqual(len(uops.uops), 3) + e1 = UOp(UOps.CONST, dtypes.int, arg=2) e2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple()) e3 = UOp(UOps.STORE, dtypes.int, (e2,e1)) - def _to_key(u): return (u.uop, u.dtype, u.arg) - self.assertEqual(_to_key(e1), _to_key(uops.uops[0])) - self.assertEqual(_to_key(e2), _to_key(uops.uops[1])) - self.assertEqual(_to_key(e3), _to_key(uops.uops[2])) - self.assertEqual(len(uops.uops), 3) + + self.assert_equiv_uops(e1, uops.uops[0]) + self.assert_equiv_uops(e2, uops.uops[1]) + self.assert_equiv_uops(e3, uops.uops[2]) if __name__ == '__main__': unittest.main(verbosity=2)