Add pattern matcher tests, move uop transforms from assembly to pattern (#4056)

matcher
This commit is contained in:
Szymon Ożóg
2024-04-03 18:06:43 +02:00
committed by GitHub
parent 1ea8fcbe1b
commit e5a9bff899
3 changed files with 93 additions and 20 deletions

View File

@@ -67,6 +67,27 @@ class PatternMatcher:
if _match(uop, p, store): return fxn(**store)
return None
def rewrite_graph(self, uops: UOpGraph):
replace: Dict[UOp, UOp] = {}
seen: Set[UOp] = set()
for u in uops:
if u in seen: continue
seen.add(u)
for o,n in replace.items():
if o in u.vin and u is not n:
u.vin = tuple(n if x == o else x for x in u.vin)
if rew := self.rewrite(u): replace[u] = rew
for o,n in replace.items():
queue = [n]
while queue:
if all([qq in uops.uops for qq in queue[-1].vin]):
q = queue.pop()
new = uops.add(q.uop, q.dtype, q.vin, q.arg, insert_before=max([0]+[uops.uops.index(vv) for vv in q.vin])+1)
for vv in uops.uops + queue: vv.vin = tuple(new if x is q else x for x in vv.vin)
else: queue.extend([qq for qq in queue[-1].vin if qq not in uops.uops])
if not any([o in u.vin for u in uops]): uops.uops.remove(o)
constant_folder = PatternMatcher([
# const rules
({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),

View File

@@ -1,4 +1,4 @@
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Set
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
import functools, struct
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
@@ -76,25 +76,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
# here we do a pretransform on UOps to fix some shortcomings of PTX
# all uops must be a register
replace: Dict[UOp, UOp] = {}
seen: Set[UOp] = set()
for u in uops:
if u in seen: continue
seen.add(u)
for o,n in replace.items():
if o in u.vin and u is not n:
u.vin = tuple(n if x == o else x for x in u.vin)
if rew := matcher.rewrite(u): replace[u] = rew
for o,n in replace.items():
queue = [n]
while queue:
if all([qq in uops.uops for qq in queue[-1].vin]):
q = queue.pop()
new = uops.add(q.uop, q.dtype, q.vin, q.arg, insert_before=max([uops.uops.index(vv) for vv in q.vin])+1)
for vv in uops.uops + queue: vv.vin = tuple(new if x is q else x for x in vv.vin)
else: queue.extend([qq for qq in queue[-1].vin if qq not in uops.uops])
uops.uops.remove(o)
matcher.rewrite_graph(uops)
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))