# Patch overview (review notes placed ahead of the first "diff --git" header; not part of any hunk).
#
# NOTE(review): the line breaks of this patch have been collapsed into single spaces. The patch
# text below is kept verbatim; its line structure would have to be restored before a patch tool
# could consume it.
#
# 1) test/test_pattern_matcher.py  (new file, 70 lines per the hunk header)
#    unittest cases for PatternMatcher:
#    - test_simple_match: a {"uop", "dtype"} pattern returns the float CONST it was given and
#      None for an int CONST.
#    - test_vin_one: "vin" given as a tuple of two CONST patterns matches an ALU over two CONSTs;
#      a bare CONST yields None.
#    - test_vin_permutations: "vin" given as a list [CONST, ALU] is expected to match ALU inputs
#      (ALU, CONST) and (CONST, ALU), and to yield None for (CONST, CONST) and (ALU, ALU).
#    - test_vin_repeat: "vin" given as a single dict is expected to match (CONST, CONST) inputs
#      and to yield None for (CONST, ALU).
#    - test_rewrite_graph_folds: after rewrite_graph, the last uop is expected to have the key
#      (uop, dtype, arg) == (CONST, dtypes.int, 4).
#    - test_rewrite_graph_adds: after rewrite_graph and remove_childless, the graph is expected
#      to hold exactly three uops, in order: CONST, DEFINE_GLOBAL, STORE.
#
# 2) tinygrad/codegen/uops.py
#    Adds PatternMatcher.rewrite_graph(uops), which works in two passes:
#    - Pass 1 iterates over the graph; for each uop not yet seen it substitutes already-recorded
#      replacements into the uop's vin, then calls self.rewrite(u) and stores a truthy result
#      in `replace`.
#    - Pass 2 handles each (old, new) pair with a work list: an entry is popped and passed to
#      uops.add only once all of its vin are in uops.uops, otherwise its missing inputs are
#      pushed first. insert_before is (highest index among its vin) + 1, with 0 seeding the
#      max(). vin entries that pointed at the popped object are re-pointed to whatever
#      uops.add returned. The old uop is removed only when no uop still lists it in vin.
#
# 3) tinygrad/renderer/assembly.py
#    uops_to_asm drops its inline copy of the logic above and calls matcher.rewrite_graph(uops);
#    `Set` is dropped from the typing import.
#    The added method differs from the removed inline copy in two places:
#    - max([0] + [...]) replaces max([...]), so an entry with empty vin no longer calls max()
#      on an empty list;
#    - uops.uops.remove(o) is now guarded by "no uop still has o in its vin" (it was
#      unconditional before).
#
# NOTE(review): the uops.py hunk does not show that file's import lines; confirm that Dict and
# Set are in scope there for the new local annotations.
diff --git a/test/test_pattern_matcher.py b/test/test_pattern_matcher.py new file mode 100644 index 0000000000..2ea730ca8d --- /dev/null +++ b/test/test_pattern_matcher.py @@ -0,0 +1,70 @@ +import unittest +from tinygrad.dtype import dtypes +from tinygrad.ops import BinaryOps +from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp + +class TestPatternMatcher(unittest.TestCase): + def test_simple_match(self): + matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float}, lambda x: x)]) + c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) + c2 = UOp(UOps.CONST, dtypes.int, arg=1) + self.assertEqual(matcher.rewrite(c1), c1) + self.assertEqual(matcher.rewrite(c2), None) + + def test_vin_one(self): + matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":({"uop": UOps.CONST}, {"uop": UOps.CONST})}, lambda x: x)]) + c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) + c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) + c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) + self.assertEqual(matcher.rewrite(c3), c3) + self.assertEqual(matcher.rewrite(c2), None) + + def test_vin_permutations(self): + matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":[{"uop": UOps.CONST}, {"uop": UOps.ALU}]}, lambda x: x)]) + c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) + c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) + c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) + c4 = UOp(UOps.ALU, dtypes.float, (c3,c2), BinaryOps.ADD) + c5 = UOp(UOps.ALU, dtypes.float, (c2,c3), BinaryOps.ADD) + c6 = UOp(UOps.ALU, dtypes.float, (c3,c4), BinaryOps.ADD) + self.assertEqual(matcher.rewrite(c3), None) + self.assertEqual(matcher.rewrite(c4), c4) + self.assertEqual(matcher.rewrite(c5), c5) + self.assertEqual(matcher.rewrite(c6), None) + + def test_vin_repeat(self): + matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":{"uop": UOps.CONST}}, lambda x: x)]) + c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) + c2 = UOp(UOps.CONST, 
dtypes.float, arg=2.0) + c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) + c4 = UOp(UOps.ALU, dtypes.float, (c2,c3), BinaryOps.ADD) + self.assertEqual(matcher.rewrite(c3), c3) + self.assertEqual(matcher.rewrite(c4), None) + + def test_rewrite_graph_folds(self): + uops = UOpGraph() + uops.add(UOps.CONST, dtypes.float, arg=2.0, simplify=False) + matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float}, + lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x,x), BinaryOps.ADD),)))]) + matcher.rewrite_graph(uops) + def _to_key(u): return (u.uop, u.dtype, u.arg) + self.assertEqual(_to_key(UOp(UOps.CONST, dtypes.int, arg=4)), _to_key(uops.uops[-1])) + + def test_rewrite_graph_adds(self): + uops = UOpGraph() + uops.add(UOps.CONST, dtypes.int, arg=2, simplify=False) + matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.int}, + lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))]) + matcher.rewrite_graph(uops) + uops.remove_childless(set(x for x in uops if x.uop in {UOps.STORE})) + e1 = UOp(UOps.CONST, dtypes.int, arg=2) + e2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple()) + e3 = UOp(UOps.STORE, dtypes.int, (e2,e1)) + def _to_key(u): return (u.uop, u.dtype, u.arg) + self.assertEqual(_to_key(e1), _to_key(uops.uops[0])) + self.assertEqual(_to_key(e2), _to_key(uops.uops[1])) + self.assertEqual(_to_key(e3), _to_key(uops.uops[2])) + self.assertEqual(len(uops.uops), 3) + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 09156c7a87..6006ad7512 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -67,6 +67,27 @@ class PatternMatcher: if _match(uop, p, store): return fxn(**store) return None + def rewrite_graph(self, uops: UOpGraph): + replace: Dict[UOp, UOp] = {} + seen: Set[UOp] = set() + for u in uops: + if u in seen: continue + seen.add(u) + for o,n in replace.items(): + 
if o in u.vin and u is not n: + u.vin = tuple(n if x == o else x for x in u.vin) + if rew := self.rewrite(u): replace[u] = rew + + for o,n in replace.items(): + queue = [n] + while queue: + if all([qq in uops.uops for qq in queue[-1].vin]): + q = queue.pop() + new = uops.add(q.uop, q.dtype, q.vin, q.arg, insert_before=max([0]+[uops.uops.index(vv) for vv in q.vin])+1) + for vv in uops.uops + queue: vv.vin = tuple(new if x is q else x for x in vv.vin) + else: queue.extend([qq for qq in queue[-1].vin if qq not in uops.uops]) + if not any([o in u.vin for u in uops]): uops.uops.remove(o) + constant_folder = PatternMatcher([ # const rules ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)), diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index f7915c3067..e56014e1e4 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -1,4 +1,4 @@ -from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Set +from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple import functools, struct from collections import defaultdict from tinygrad.codegen.linearizer import UOps, UOp @@ -76,25 +76,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: # here we do a pretransform on UOps to fix some shortcomings of PTX # all uops must be a register - replace: Dict[UOp, UOp] = {} - seen: Set[UOp] = set() - for u in uops: - if u in seen: continue - seen.add(u) - for o,n in replace.items(): - if o in u.vin and u is not n: - u.vin = tuple(n if x == o else x for x in u.vin) - if rew := matcher.rewrite(u): replace[u] = rew - - for o,n in replace.items(): - queue = [n] - while queue: - if all([qq in uops.uops for qq in queue[-1].vin]): - q = queue.pop() - new = uops.add(q.uop, q.dtype, q.vin, q.arg, insert_before=max([uops.uops.index(vv) for vv in q.vin])+1) - for vv in uops.uops + queue: vv.vin = tuple(new 
if x is q else x for x in vv.vin) - else: queue.extend([qq for qq in queue[-1].vin if qq not in uops.uops]) - uops.uops.remove(o) + matcher.rewrite_graph(uops) for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops) uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))