diff --git a/test/test_viz.py b/test/test_viz.py index c982804865..2edbea3d8a 100644 --- a/test/test_viz.py +++ b/test/test_viz.py @@ -4,11 +4,11 @@ from tinygrad.dtype import dtypes from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, contexts, track_rewrites from tinygrad.viz.serve import get_details, get_metadata, uop_to_json -@track_rewrites() -def rewrite(sink:UOp, pm:PatternMatcher, ctx=None): return graph_rewrite(sink, pm, ctx) +@track_rewrites(named=True) +def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs) -def helper_test_viz(sink:UOp, pm:PatternMatcher, ctx=None) -> List[UOp]: - rewrite(sink, pm, ctx) +def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]: + rewrite(sink, pm, **kwargs) assert len(contexts) == 1 assert len(contexts[0][1]) == 1 k = get_metadata(contexts)[0][0] @@ -52,7 +52,7 @@ class TestViz(unittest.TestCase): pm = PatternMatcher([ (UPat(Ops.LOAD, name="x"), store_load), ]) - uops = helper_test_viz(a+b, pm, {}) + uops = helper_test_viz(a+b, pm, ctx={}) self.assertEqual(len(uops), 2) self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {})) @@ -89,5 +89,25 @@ class TestViz(unittest.TestCase): assert not any(v[0].startswith("CONST") for v in graph.values()) assert len([x for x in graph.values() if "CONST" in x[0]]) == 1 + def test_bottom_up_rewrite(self): + a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) + n1 = a.sin() + uop = n1.sin() + pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) + ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True) + self.assertEqual(len(ret), 2) + self.assertIs(ret[0], a.sin().sqrt()) # first rewrite + self.assertIs(ret[1], a.sqrt().sqrt()) # second one + + def test_top_down_rewrite(self): + a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) + n1 = a.sin() + uop = n1.sin() + pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) + # if it wasn't bottom_up, it's rewritten once + ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=False) + self.assertEqual(len(ret), 1) + self.assertIs(ret[0], a.sqrt().sin()) # only rewrite + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 58e97c1ded..95322415f2 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -688,6 +688,7 @@ match_stats:Dict[UPat, List[Union[int, float]]] = dict() class TrackedRewriteContext: loc: Tuple[str, int] # location that called graph_rewrite sink: UOp # the sink passed into the rewrite + bottom_up: bool matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = [] @@ -775,7 +776,7 @@ class RewriteContext: def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp: if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: - rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)) + rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up)) return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink) # ***** uop type spec ***** diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 5ffc9b9625..6b4bd8b7a4 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -82,14 +82,14 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) replaces: Dict[UOp, UOp] = {} sink = ctx.sink for i,(u0,u1,upat,_) in enumerate(ctx.matches): + if ctx.bottom_up: replaces = {} # if it's bottom_up it's single pass replaces[u0] = u0 if u1 is None else u1 # if the match didn't result in a rewrite we move forward if u1 is None: continue - # first, rewrite this UOp with the current rewrite + all the seen matches before this + # first, rewrite this UOp with the current rewrite + all the matches in replaces new_sink = _replace_uop(sink, {**replaces}) # sanity check - if new_sink is sink: - raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}") + if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}") # update ret data g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not Ops.CONST]) g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))