mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz: support nested track_rewrites (#15454)
* simple test * stack active groups
This commit is contained in:
@@ -16,7 +16,7 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non
|
||||
return sink
|
||||
|
||||
# real VIZ=1 loads the trace from a file, we just keep it in memory for tests
|
||||
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt, RewriteTrace
|
||||
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
|
||||
from tinygrad.viz import serve
|
||||
serve.trace = RewriteTrace(tracked_keys, tracked_ctxs, uop_fields)
|
||||
from tinygrad.viz.serve import get_rewrites, get_full_rewrite, uop_to_json
|
||||
@@ -29,7 +29,7 @@ def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]:
|
||||
class BaseTestViz(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clear the global context
|
||||
for lst in [tracked_keys, tracked_ctxs, active_rewrites, _name_cnt]: lst.clear()
|
||||
for lst in [tracked_keys, tracked_ctxs, active_rewrites, active_group, _name_cnt]: lst.clear()
|
||||
Buffer.profile_events.clear()
|
||||
cpu_events.clear()
|
||||
self.tms = TRACK_MATCH_STATS.value
|
||||
@@ -121,6 +121,26 @@ class TestViz(BaseTestViz):
|
||||
# NOTE: names from TracingKey do not get deduped
|
||||
self.assertEqual(lst[0]["name"], "custom_name")
|
||||
|
||||
def test_nested_track_rewrites(self):
|
||||
@track_rewrites(name=lambda x,ret: TracingKey(f"inner fxn for {x.render()}", (ret,)))
|
||||
def inner(x:UOp): return graph_rewrite(x, PatternMatcher([]), name="each")
|
||||
@track_rewrites(name=lambda *args,ret: f"outer rewrite of {len(args)} inputs")
|
||||
def outer(*xs:tuple[UOp, ...]): return graph_rewrite(UOp.sink(*[inner(x) for x in xs]), PatternMatcher([]), name="all")
|
||||
items = ["a", "b", "c"]
|
||||
outer(*[UOp.variable(x, 1, 10) for x in items])
|
||||
lst = get_viz_list()
|
||||
# inner calls fall outside the outer call
|
||||
self.assertEqual(len(lst), len(items)+1)
|
||||
self.assertEqual(lst[0]["name"], f"outer rewrite of {len(items)} inputs n1")
|
||||
steps = lst[0]["steps"]
|
||||
self.assertEqual(len(steps), 1)
|
||||
self.assertEqual(steps[0]["name"], "all")
|
||||
for i in range(len(items)):
|
||||
self.assertEqual(lst[i+1]["name"], f"inner fxn for {items[i]}")
|
||||
steps = lst[i+1]["steps"]
|
||||
self.assertEqual(len(steps), 1)
|
||||
self.assertEqual(steps[0]["name"], "each")
|
||||
|
||||
def test_profile_matches(self):
|
||||
@profile_matches
|
||||
def nested_function(u:UOp):
|
||||
|
||||
@@ -1227,17 +1227,22 @@ def add_trace_group(kt:TracingKey) -> None:
|
||||
tracked_keys.append(kt)
|
||||
tracked_ctxs.append([])
|
||||
|
||||
active_group:list[int] = []
|
||||
def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=False):
|
||||
def _decorator(func):
|
||||
def __wrapper(*args, **kwargs):
|
||||
fn = key = func.__name__
|
||||
if TRACK_MATCH_STATS >= 2: add_trace_group(key:=TracingKey(n:=f"{fn} n{next(_name_cnt.setdefault(fn, itertools.count(1)))}", (n,)))
|
||||
idx = -1
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
add_trace_group(key:=TracingKey(n:=f"{fn} n{next(_name_cnt.setdefault(fn, itertools.count(1)))}", (n,)))
|
||||
active_group.append(idx:=len(tracked_keys)-1)
|
||||
with cpu_profile(key, "TINY") as e:
|
||||
ret = func(*args, **kwargs)
|
||||
if TRACK_MATCH_STATS >= 2: active_group.pop()
|
||||
if TRACK_MATCH_STATS >= 2 and callable(name):
|
||||
name_ret = name(*args, **kwargs, ret=ret)
|
||||
assert isinstance(name_ret, (TracingKey, str)), f"name function returned {type(name_ret)}"
|
||||
tracked_keys[-1] = k = TracingKey(n:=tracked_keys[-1].display_name.replace(fn, name_ret), (n,)) if isinstance(name_ret, str) else name_ret
|
||||
tracked_keys[idx] = k = TracingKey(n:=tracked_keys[idx].display_name.replace(fn, name_ret), (n,)) if isinstance(name_ret, str) else name_ret
|
||||
e.name = TracingKey(k.display_name if isinstance(name_ret, str) else f"{fn} for {k.display_name}", k.keys)
|
||||
if CAPTURE_PROCESS_REPLAY and replay:
|
||||
# find the unittest frame we're capturing in
|
||||
@@ -1260,7 +1265,8 @@ def profile_matches(fxn:Callable):
|
||||
loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
|
||||
depth = len(active_rewrites)
|
||||
if not tracked_ctxs: add_trace_group(TracingKey(f"default {fxn.__name__}"))
|
||||
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False)))
|
||||
dest_group = active_group[-1] if active_group else len(tracked_ctxs)-1
|
||||
tracked_ctxs[dest_group].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False)))
|
||||
active_rewrites.append(ctx)
|
||||
with cpu_profile(name, "TINY"):
|
||||
ret = fxn(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user