From 652bab8aad55d3ead46e5df252fb49b0d5497965 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:01:30 +0200 Subject: [PATCH] viz: support nested track_rewrites (#15454) * simple test * stack active groups --- test/null/test_viz.py | 24 ++++++++++++++++++++++-- tinygrad/uop/ops.py | 12 +++++++++--- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/test/null/test_viz.py b/test/null/test_viz.py index d840a2ea9b..1d190380d5 100644 --- a/test/null/test_viz.py +++ b/test/null/test_viz.py @@ -16,7 +16,7 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non return sink # real VIZ=1 loads the trace from a file, we just keep it in memory for tests -from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt, RewriteTrace +from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace from tinygrad.viz import serve serve.trace = RewriteTrace(tracked_keys, tracked_ctxs, uop_fields) from tinygrad.viz.serve import get_rewrites, get_full_rewrite, uop_to_json @@ -29,7 +29,7 @@ def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]: class BaseTestViz(unittest.TestCase): def setUp(self): # clear the global context - for lst in [tracked_keys, tracked_ctxs, active_rewrites, _name_cnt]: lst.clear() + for lst in [tracked_keys, tracked_ctxs, active_rewrites, active_group, _name_cnt]: lst.clear() Buffer.profile_events.clear() cpu_events.clear() self.tms = TRACK_MATCH_STATS.value @@ -121,6 +121,26 @@ class TestViz(BaseTestViz): # NOTE: names from TracingKey do not get deduped self.assertEqual(lst[0]["name"], "custom_name") + def test_nested_track_rewrites(self): + @track_rewrites(name=lambda x,ret: TracingKey(f"inner fxn for {x.render()}", (ret,))) + def inner(x:UOp): return graph_rewrite(x, PatternMatcher([]), name="each") + @track_rewrites(name=lambda *args,ret: f"outer rewrite of {len(args)} inputs") + def outer(*xs:tuple[UOp, ...]): return graph_rewrite(UOp.sink(*[inner(x) for x in xs]), PatternMatcher([]), name="all") + items = ["a", "b", "c"] + outer(*[UOp.variable(x, 1, 10) for x in items]) + lst = get_viz_list() + # inner calls fall outside the outer call + self.assertEqual(len(lst), len(items)+1) + self.assertEqual(lst[0]["name"], f"outer rewrite of {len(items)} inputs n1") + steps = lst[0]["steps"] + self.assertEqual(len(steps), 1) + self.assertEqual(steps[0]["name"], "all") + for i in range(len(items)): + self.assertEqual(lst[i+1]["name"], f"inner fxn for {items[i]}") + steps = lst[i+1]["steps"] + self.assertEqual(len(steps), 1) + self.assertEqual(steps[0]["name"], "each") + def test_profile_matches(self): @profile_matches def nested_function(u:UOp): diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 4f9f270874..8b25492cfe 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1227,17 +1227,22 @@ def add_trace_group(kt:TracingKey) -> None: tracked_keys.append(kt) tracked_ctxs.append([]) +active_group:list[int] = [] def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=False): def _decorator(func): def __wrapper(*args, **kwargs): fn = key = func.__name__ - if TRACK_MATCH_STATS >= 2: add_trace_group(key:=TracingKey(n:=f"{fn} n{next(_name_cnt.setdefault(fn, itertools.count(1)))}", (n,))) + idx = -1 + if TRACK_MATCH_STATS >= 2: + add_trace_group(key:=TracingKey(n:=f"{fn} n{next(_name_cnt.setdefault(fn, itertools.count(1)))}", (n,))) + active_group.append(idx:=len(tracked_keys)-1) with cpu_profile(key, "TINY") as e: ret = func(*args, **kwargs) + if TRACK_MATCH_STATS >= 2: active_group.pop() if TRACK_MATCH_STATS >= 2 and callable(name): name_ret = name(*args, **kwargs, ret=ret) assert isinstance(name_ret, (TracingKey, str)), f"name function returned {type(name_ret)}" - tracked_keys[-1] = k = TracingKey(n:=tracked_keys[-1].display_name.replace(fn, name_ret), (n,)) if isinstance(name_ret, str) else name_ret + tracked_keys[idx] = k = TracingKey(n:=tracked_keys[idx].display_name.replace(fn, name_ret), (n,)) if isinstance(name_ret, str) else name_ret e.name = TracingKey(k.display_name if isinstance(name_ret, str) else f"{fn} for {k.display_name}", k.keys) if CAPTURE_PROCESS_REPLAY and replay: # find the unittest frame we're capturing in @@ -1260,7 +1265,8 @@ def profile_matches(fxn:Callable): loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno) depth = len(active_rewrites) if not tracked_ctxs: add_trace_group(TracingKey(f"default {fxn.__name__}")) - tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False))) + dest_group = active_group[-1] if active_group else len(tracked_ctxs)-1 + tracked_ctxs[dest_group].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False))) active_rewrites.append(ctx) with cpu_profile(name, "TINY"): ret = fxn(*args, **kwargs)