mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
RewriteContext [run_process_replay] (#6428)
This commit is contained in:
@@ -721,15 +721,17 @@ if TRACK_MATCH_STATS:
|
||||
|
||||
# *** simple graph rewrite engine ***
|
||||
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
|
||||
nodes: Dict[Tuple, UOp] = {}
|
||||
replace: Dict[UOp, UOp] = {}
|
||||
def __inner_rewrite(n:UOp) -> UOp:
|
||||
if rn := replace.get(n): return rn
|
||||
replace_source = (n.op, n.dtype, new_src:=tuple(__inner_rewrite(y) for y in n.src), n.arg)
|
||||
if found := nodes.get(replace_source): replace[n] = found
|
||||
class RewriteContext:
|
||||
def __init__(self, pm):
|
||||
self.pm: PatternMatcher = pm
|
||||
self.nodes: Dict[Tuple, UOp] = {}
|
||||
self.replace: Dict[UOp, UOp] = {}
|
||||
def rewrite(self, n:UOp) -> UOp:
|
||||
if rn := self.replace.get(n): return rn
|
||||
replace_source = (n.op, n.dtype, new_src:=tuple(self.rewrite(y) for y in n.src), n.arg)
|
||||
if found := self.nodes.get(replace_source): self.replace[n] = found
|
||||
else:
|
||||
x = UOp(*replace_source) if new_src != n.src else n
|
||||
nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x)) else x
|
||||
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
|
||||
return found
|
||||
return __inner_rewrite(sink)
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return RewriteContext(pm).rewrite(sink)
|
||||
|
||||
Reference in New Issue
Block a user