mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 23:25:04 -05:00
Add pattern matcher tests, move uop transforms from assembly to pattern (#4056)
matcher
This commit is contained in:
@@ -67,6 +67,27 @@ class PatternMatcher:
|
||||
if _match(uop, p, store): return fxn(**store)
|
||||
return None
|
||||
|
||||
def rewrite_graph(self, uops: UOpGraph):
|
||||
replace: Dict[UOp, UOp] = {}
|
||||
seen: Set[UOp] = set()
|
||||
for u in uops:
|
||||
if u in seen: continue
|
||||
seen.add(u)
|
||||
for o,n in replace.items():
|
||||
if o in u.vin and u is not n:
|
||||
u.vin = tuple(n if x == o else x for x in u.vin)
|
||||
if rew := self.rewrite(u): replace[u] = rew
|
||||
|
||||
for o,n in replace.items():
|
||||
queue = [n]
|
||||
while queue:
|
||||
if all([qq in uops.uops for qq in queue[-1].vin]):
|
||||
q = queue.pop()
|
||||
new = uops.add(q.uop, q.dtype, q.vin, q.arg, insert_before=max([0]+[uops.uops.index(vv) for vv in q.vin])+1)
|
||||
for vv in uops.uops + queue: vv.vin = tuple(new if x is q else x for x in vv.vin)
|
||||
else: queue.extend([qq for qq in queue[-1].vin if qq not in uops.uops])
|
||||
if not any([o in u.vin for u in uops]): uops.uops.remove(o)
|
||||
|
||||
constant_folder = PatternMatcher([
|
||||
# const rules
|
||||
({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Set
|
||||
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
|
||||
import functools, struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
@@ -76,25 +76,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
|
||||
# here we do a pretransform on UOps to fix some shortcomings of PTX
|
||||
# all uops must be a register
|
||||
replace: Dict[UOp, UOp] = {}
|
||||
seen: Set[UOp] = set()
|
||||
for u in uops:
|
||||
if u in seen: continue
|
||||
seen.add(u)
|
||||
for o,n in replace.items():
|
||||
if o in u.vin and u is not n:
|
||||
u.vin = tuple(n if x == o else x for x in u.vin)
|
||||
if rew := matcher.rewrite(u): replace[u] = rew
|
||||
|
||||
for o,n in replace.items():
|
||||
queue = [n]
|
||||
while queue:
|
||||
if all([qq in uops.uops for qq in queue[-1].vin]):
|
||||
q = queue.pop()
|
||||
new = uops.add(q.uop, q.dtype, q.vin, q.arg, insert_before=max([uops.uops.index(vv) for vv in q.vin])+1)
|
||||
for vv in uops.uops + queue: vv.vin = tuple(new if x is q else x for x in vv.vin)
|
||||
else: queue.extend([qq for qq in queue[-1].vin if qq not in uops.uops])
|
||||
uops.uops.remove(o)
|
||||
matcher.rewrite_graph(uops)
|
||||
|
||||
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
|
||||
uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
|
||||
|
||||
Reference in New Issue
Block a user