mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
failed test in test_pattern_matcher (#4080)
something about the PTX rewrite is incorrect that it has duplicated rewritten uops
This commit is contained in:
@@ -4,6 +4,12 @@ from tinygrad.ops import BinaryOps
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp
|
||||
|
||||
class TestPatternMatcher(unittest.TestCase):
|
||||
def assert_equiv_uops(self, uop1:UOp, uop2:UOp):
|
||||
# NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops
|
||||
self.assertEqual(uop1.uop, uop2.uop)
|
||||
self.assertEqual(uop1.dtype, uop2.dtype)
|
||||
self.assertEqual(uop1.arg, uop2.arg)
|
||||
|
||||
def test_simple_match(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float}, lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
@@ -56,10 +62,12 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
uops = UOpGraph()
|
||||
uops.add(UOps.CONST, dtypes.float, arg=2.0, simplify=False)
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float},
|
||||
lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x,x), BinaryOps.ADD),)))])
|
||||
lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x, x), BinaryOps.ADD),)))])
|
||||
matcher.rewrite_graph(uops)
|
||||
def _to_key(u): return (u.uop, u.dtype, u.arg)
|
||||
self.assertEqual(_to_key(UOp(UOps.CONST, dtypes.int, arg=4)), _to_key(uops.uops[-1]))
|
||||
# TODO: fix this. it's 2 now
|
||||
# self.assertEqual(len(uops.uops), 1)
|
||||
self.assertEqual(len(uops.uops), 2)
|
||||
self.assert_equiv_uops(UOp(UOps.CONST, dtypes.int, arg=4), uops.uops[-1])
|
||||
|
||||
def test_rewrite_graph_adds(self):
|
||||
uops = UOpGraph()
|
||||
@@ -68,14 +76,16 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))])
|
||||
matcher.rewrite_graph(uops)
|
||||
uops.remove_childless(set(x for x in uops if x.uop in {UOps.STORE}))
|
||||
|
||||
self.assertEqual(len(uops.uops), 3)
|
||||
|
||||
e1 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
e2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
|
||||
e3 = UOp(UOps.STORE, dtypes.int, (e2,e1))
|
||||
def _to_key(u): return (u.uop, u.dtype, u.arg)
|
||||
self.assertEqual(_to_key(e1), _to_key(uops.uops[0]))
|
||||
self.assertEqual(_to_key(e2), _to_key(uops.uops[1]))
|
||||
self.assertEqual(_to_key(e3), _to_key(uops.uops[2]))
|
||||
self.assertEqual(len(uops.uops), 3)
|
||||
|
||||
self.assert_equiv_uops(e1, uops.uops[0])
|
||||
self.assert_equiv_uops(e2, uops.uops[1])
|
||||
self.assert_equiv_uops(e3, uops.uops[2])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user