diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 0d3bb8d8da..2aee009083 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -1,6 +1,6 @@ import unittest, decimal, json from tinygrad.dtype import dtypes -from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat +from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat, Ops from tinygrad.codegen.symbolic import symbolic from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry @@ -14,6 +14,10 @@ inner_rewrite = TrackedPatternMatcher([ (UPat.cvar("x"), lambda x: None if x.dtype == dtypes.float32 else UOp.const(dtypes.float32, x.arg)), ]) +l2 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=2, name="x"), lambda x: x.replace(arg=3))]) +l1 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=1, name="x"), lambda x: graph_rewrite(x.replace(arg=2), l2))]) +l0 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=0, name="x"), lambda x: graph_rewrite(x.replace(arg=1), l1))]) + class TestViz(unittest.TestCase): def setUp(self): # clear the global context @@ -170,10 +174,24 @@ class TestViz(unittest.TestCase): self.assertEqual(len(contexts), 1) tracked = contexts[0] self.assertEqual(len(tracked), 3) + self.assertEqual(tracked[0].depth, 0) + self.assertEqual(tracked[1].depth, 1) + self.assertEqual(tracked[2].depth, 1) # NOTE: this is sorted by the time called, maybe it should be by depth self.assertEqual([x.name for x in tracked], ["outer", "inner_x", "inner_y"]) self.assertEqual([len(x.matches) for x in tracked], [1, 1, 1]) + def test_depth_level(self): + @track_rewrites(named=True) + def fxn(u:UOp): return graph_rewrite(u, l0) + ret = fxn(UOp(Ops.CUSTOM, arg=0)) + assert ret is UOp(Ops.CUSTOM, arg=3) + self.assertEqual(len(contexts), 1) + tracked = contexts[0] + self.assertEqual(tracked[0].depth, 0) + self.assertEqual(tracked[1].depth, 1) + self.assertEqual(tracked[2].depth, 2) + def test_shape_label(self): a = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((4,)) b = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((8,)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 4b01e19896..40714ba582 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -879,6 +879,7 @@ class TrackedGraphRewrite: bottom_up: bool matches: list[tuple[UOp, UOp, UPat]] # before+after of all the matches name: str|None + depth: int tracked_keys:list[Any] = [] tracked_ctxs:list[list[TrackedGraphRewrite]] = [] _name_cnt:dict[str, int] = {} @@ -900,7 +901,8 @@ def track_matches(func): def _track_func(*args, **kwargs): if tracking:=(TRACK_MATCH_STATS >= 2 and tracked_ctxs): loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno) - tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0], kwargs.get("bottom_up", False), [], kwargs.get("name", None))) + depth = len(active_rewrites) + tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0], kwargs.get("bottom_up", False),[], kwargs.get("name", None), depth)) active_rewrites.append(ctx) ret = func(*args, **kwargs) if tracking: active_rewrites.pop() diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index bc17d635c9..fac5b960e3 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -297,6 +297,7 @@ async function main() { const inner = ul.appendChild(document.createElement("ul")); if (i === currentKernel && j === currentUOp) inner.className = "active"; inner.innerText = `${u.name ?? u.loc[0].replaceAll("\\", "/").split("/").pop()+':'+u.loc[1]} - ${u.match_count}`; + inner.style.marginLeft = `${8*u.depth}px`; inner.style.display = i === currentKernel && expandKernel ? "block" : "none"; inner.onclick = (e) => { e.stopPropagation(); diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 0444424c86..e79371630e 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -31,13 +31,14 @@ class GraphRewriteMetadata(TypedDict): code_line: str # source code calling graph_rewrite kernel_code: str|None # optionally render the final kernel code name: str|None # optional name of the rewrite + depth: int # depth if it's a subrewrite @functools.cache def render_program(k:Kernel): return k.opts.render(k.uops) def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata: return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(), - "kernel_code":pcall(render_program, k) if isinstance(k, Kernel) else None, "name":v.name} + "kernel_code":pcall(render_program, k) if isinstance(k, Kernel) else None, "name":v.name, "depth":v.depth} def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]: return [(k.name if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]