viz: support nested track_rewrites (#15454)

* simple test

* stack active groups
This commit is contained in:
qazal
2026-03-24 22:01:30 +02:00
committed by GitHub
parent 41eb2cc41b
commit 652bab8aad
2 changed files with 31 additions and 5 deletions

View File

@@ -16,7 +16,7 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non
return sink
# real VIZ=1 loads the trace from a file, we just keep it in memory for tests
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt, RewriteTrace
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
from tinygrad.viz import serve
serve.trace = RewriteTrace(tracked_keys, tracked_ctxs, uop_fields)
from tinygrad.viz.serve import get_rewrites, get_full_rewrite, uop_to_json
@@ -29,7 +29,7 @@ def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]:
class BaseTestViz(unittest.TestCase):
def setUp(self):
# clear the global context
for lst in [tracked_keys, tracked_ctxs, active_rewrites, _name_cnt]: lst.clear()
for lst in [tracked_keys, tracked_ctxs, active_rewrites, active_group, _name_cnt]: lst.clear()
Buffer.profile_events.clear()
cpu_events.clear()
self.tms = TRACK_MATCH_STATS.value
@@ -121,6 +121,26 @@ class TestViz(BaseTestViz):
# NOTE: names from TracingKey do not get deduped
self.assertEqual(lst[0]["name"], "custom_name")
def test_nested_track_rewrites(self):
@track_rewrites(name=lambda x,ret: TracingKey(f"inner fxn for {x.render()}", (ret,)))
def inner(x:UOp): return graph_rewrite(x, PatternMatcher([]), name="each")
@track_rewrites(name=lambda *args,ret: f"outer rewrite of {len(args)} inputs")
def outer(*xs:tuple[UOp, ...]): return graph_rewrite(UOp.sink(*[inner(x) for x in xs]), PatternMatcher([]), name="all")
items = ["a", "b", "c"]
outer(*[UOp.variable(x, 1, 10) for x in items])
lst = get_viz_list()
# inner calls fall outside the outer call
self.assertEqual(len(lst), len(items)+1)
self.assertEqual(lst[0]["name"], f"outer rewrite of {len(items)} inputs n1")
steps = lst[0]["steps"]
self.assertEqual(len(steps), 1)
self.assertEqual(steps[0]["name"], "all")
for i in range(len(items)):
self.assertEqual(lst[i+1]["name"], f"inner fxn for {items[i]}")
steps = lst[i+1]["steps"]
self.assertEqual(len(steps), 1)
self.assertEqual(steps[0]["name"], "each")
def test_profile_matches(self):
@profile_matches
def nested_function(u:UOp):

View File

@@ -1227,17 +1227,22 @@ def add_trace_group(kt:TracingKey) -> None:
tracked_keys.append(kt)
tracked_ctxs.append([])
active_group:list[int] = []
def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=False):
def _decorator(func):
def __wrapper(*args, **kwargs):
fn = key = func.__name__
if TRACK_MATCH_STATS >= 2: add_trace_group(key:=TracingKey(n:=f"{fn} n{next(_name_cnt.setdefault(fn, itertools.count(1)))}", (n,)))
idx = -1
if TRACK_MATCH_STATS >= 2:
add_trace_group(key:=TracingKey(n:=f"{fn} n{next(_name_cnt.setdefault(fn, itertools.count(1)))}", (n,)))
active_group.append(idx:=len(tracked_keys)-1)
with cpu_profile(key, "TINY") as e:
ret = func(*args, **kwargs)
if TRACK_MATCH_STATS >= 2: active_group.pop()
if TRACK_MATCH_STATS >= 2 and callable(name):
name_ret = name(*args, **kwargs, ret=ret)
assert isinstance(name_ret, (TracingKey, str)), f"name function returned {type(name_ret)}"
tracked_keys[-1] = k = TracingKey(n:=tracked_keys[-1].display_name.replace(fn, name_ret), (n,)) if isinstance(name_ret, str) else name_ret
tracked_keys[idx] = k = TracingKey(n:=tracked_keys[idx].display_name.replace(fn, name_ret), (n,)) if isinstance(name_ret, str) else name_ret
e.name = TracingKey(k.display_name if isinstance(name_ret, str) else f"{fn} for {k.display_name}", k.keys)
if CAPTURE_PROCESS_REPLAY and replay:
# find the unittest frame we're capturing in
@@ -1260,7 +1265,8 @@ def profile_matches(fxn:Callable):
loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
depth = len(active_rewrites)
if not tracked_ctxs: add_trace_group(TracingKey(f"default {fxn.__name__}"))
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False)))
dest_group = active_group[-1] if active_group else len(tracked_ctxs)-1
tracked_ctxs[dest_group].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False)))
active_rewrites.append(ctx)
with cpu_profile(name, "TINY"):
ret = fxn(*args, **kwargs)