From 5c1d21349e8af3fd6edd99fb954524b1150affcc Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 3 Feb 2026 05:51:25 -0500 Subject: [PATCH] viz: profiler command line tool (#14515) --- extra/viz/cli.py | 36 ++++++++++++++++++++++++++++++++++-- test/null/test_viz.py | 4 ++-- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/extra/viz/cli.py b/extra/viz/cli.py index 846717be3b..79174d4fd8 100755 --- a/extra/viz/cli.py +++ b/extra/viz/cli.py @@ -3,7 +3,8 @@ import argparse, pathlib from typing import Iterator from tinygrad.viz import serve as viz from tinygrad.uop.ops import RewriteTrace -from tinygrad.helpers import temp, ansistrip, colored +from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen +from test.null.test_viz import load_profile def optional_eq(val:dict, arg:str|None) -> bool: return arg is None or ansistrip(val["name"]) == arg @@ -27,11 +28,42 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Select a kernel by name (optional name, default: only list names)') parser.add_argument('--select', type=str, default=None, metavar="NAME", - help='Select an item within the chosen kernel (optional name, default: only list names)') + help='Rewrites: Select an item within the chosen kernel (optional name, default: only list names)') + parser.add_argument('--profile', action="store_true", help="View profiling trace (default: views rewrites)") + parser.add_argument('--device', type=str, default=None, metavar="NAME", help="Profile only: Select a device (default: prints all devices)") args = parser.parse_args() viz.trace = viz.load_pickle(pathlib.Path(temp("rewrites.pkl", append_user=True)), default=RewriteTrace([], [], {})) viz.ctxs = viz.get_rewrites(viz.trace) + + if args.profile: + from tabulate import tabulate + profile = load_profile(viz.load_pickle(pathlib.Path(temp("profile.pkl", append_user=True)), default=[])) + agg, total, n = {}, 0, 0 + for k,v in profile["layout"].items(): + if not optional_eq({"name":k}, args.device): continue + print(k) + if args.device is None: continue + for e in v.get("events", []): + et = e["dur"]*1e-6 + if args.kernel is not None: + if ansistrip(e["name"]) == args.kernel and n < 10: + ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else "" + name = e["name"]+(" " * (46 - ansilen(e["name"]))) + print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e['fmt'].replace('\n', ' | ')+" ") + n += 1 + else: + a = agg.setdefault(e["name"], [0.0, 0]) + a[0] += et + a[1] += 1 + total += et + + if agg: + rows = [[n, t, time_to_str(t, w=9), t / c if c else 0.0, c, (t / total * 100.0) if total else 0.0] for n, (t, c) in agg.items()] + rows.sort(key=lambda r: r[1], reverse=True) + print(tabulate([[r[0], r[2], r[4], f"{r[5]:.2f}%"] for r in rows[:30]], headers=["name", "total", "count", "pct"], tablefmt="github")) + exit(0) + for k in viz.ctxs: if not optional_eq(k, args.kernel): continue print(k["name"]) diff --git a/test/null/test_viz.py b/test/null/test_viz.py index 5868c1cafb..bc1a57828e 100644 --- a/test/null/test_viz.py +++ b/test/null/test_viz.py @@ -353,8 +353,8 @@ def load_profile(lst:list[ProfileEvent]) -> dict: event_type, event_count = u("