From dcc2a5d23b20eef40d33650c736373919d449079 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 27 Mar 2026 21:46:39 +0200 Subject: [PATCH] viz/cli: simplify to --source and --item flags (#15510) * viz/cli: simplify to --source and --item flags * update viz cli test --- extra/viz/README | 18 +++++++++--------- extra/viz/cli.py | 30 +++++++++++++----------------- test/amd/test_sqttmap.py | 2 +- 3 files changed, 23 insertions(+), 27 deletions(-) diff --git a/extra/viz/README b/extra/viz/README index 7a16c9e8f3..eab4e7139b 100644 --- a/extra/viz/README +++ b/extra/viz/README @@ -11,18 +11,18 @@ Flags: VIZ=-1 to only save the trace to a file, VIZ=1 also launches a web server ## Inspect runtime profiling -Use `extra/viz/cli.py --profile` to list all traced devices. +Use `extra/viz/cli.py --profile` to list all sources. -List top slowest kernels on a device: `--profile --device "AMD" | head 10` -List samples of a kernel on a device: `--profile --device "AMD" --kernel E_3 | head 4` +List top slowest kernels on a source: `--profile -s "AMD" | head 10` +List samples of a kernel on a source: `--profile -s "AMD" -i E_3 | head 4` ## Inspect codegen and PatternMatcher -Use `extra/viz/cli.py --rewrites` to list all traced kernels. +Use `extra/viz/cli.py --rewrites` to list all sources. -List all codegen steps for a kernel: `--rewrites --kernel E_3` -Get source code: `--rewrites --kernel E_3 --select "View Source"` -Inspect a graph rewrite: `--rewrites --kernel E_3 --select "initial symbolic"` +List all codegen steps for a kernel: `--rewrites -s E_3` +Get source code: `--rewrites -s E_3 -i "View Source"` +Inspect a graph rewrite: `--rewrites -s E_3 -i "initial symbolic"` # SQTT tracing @@ -31,12 +31,12 @@ Supported on AMD for RDNA3 and RDNA4 (best) and CDNA (developing). Flags: VIZ=-2 to save SQTT trace to a file. VIZ=2 also launches a web server. View other flags in tinygrad/runtime/ops_amd.py to configure SQTT as needed. Use `extra/viz/cli.py --profile | grep SQTT` to view all available SQTT traces. -You can select a specific trace option with --device, Example workflow: +You can select a specific trace with --source, Example workflow: ```bash # Run amd_asm_matmul with VIZ=-2 to capture the trace VIZ=-2 python extra/gemm/amd_asm_matmul.py # View barriers -extra/viz/cli.py --profile --device "Exec kernel SQTT PKTS SE:0" | rg BARRIER | head -10 +extra/viz/cli.py --profile -s "Exec kernel SQTT PKTS SE:0" | rg BARRIER | head -10 ``` diff --git a/extra/viz/cli.py b/extra/viz/cli.py index 79d6e1283e..a9af99553b 100755 --- a/extra/viz/cli.py +++ b/extra/viz/cli.py @@ -71,14 +71,14 @@ def main(args) -> None: viz.load_amd_counters(viz.ctxs, profile_data) counters = {f'{c["name"]} SQTT {s["name"]}': s["data"] for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"] if s["name"].startswith("PKTS")} - if args.device is None: - print("Select a device:") + if args.source is None: + print("Available sources:") for k in (*profile["layout"], *counters): print(f" {format_colored(k)}") return None # ** SQTT printer - if args.device is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.device), None)) is not None: + if args.source is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.source), None)) is not None: # modern terminals support 24-bit color def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m" WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'), @@ -108,13 +108,13 @@ def main(args) -> None: # ** Profiler printer agg, total, n = {}, 0, 0 for k,v in profile["layout"].items(): - if not optional_eq({"name":k}, args.device): continue + if not optional_eq({"name":k}, args.source): continue print(f" {format_colored(k)}") - if args.device is None: continue + if args.source is None: continue for e in v.get("events", []): et = e["dur"]*1e-6 - if args.kernel is not None: - if optional_eq(e, args.kernel) and n < 10: + if args.item is not None: + if optional_eq(e, args.item) and n < 10: ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else "" name = e["name"]+(" " * (46 - ansilen(e["name"]))) print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e.get('fmt', '').replace('\n', ' | ')+" ") @@ -132,13 +132,13 @@ def main(args) -> None: # ** Graph rewrites printer for k in viz.ctxs: - if not optional_eq(k, args.kernel): continue + if not optional_eq(k, args.source): continue print(k["name"]) - if args.kernel is None: continue + if args.source is None: continue for s in k["steps"]: - if not optional_eq(s, args.select): continue + if not optional_eq(s, args.item): continue print(" "*s["depth"]+s['name']+(f" - {s['match_count']}" if s.get('match_count') is not None else '')) - if args.select is not None: print_data(viz.get_render(s['query'])) + if args.item is not None: print_data(viz.get_render(s['query'])) def get_arg_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() @@ -146,13 +146,9 @@ def get_arg_parser() -> argparse.ArgumentParser: g_mode.add_argument("--profile", action="store_true", help="View profile trace") g_mode.add_argument("--rewrites", action="store_true", help="View rewrites trace") g_common = parser.add_argument_group("common options") - g_common.add_argument("--kernel", type=str, default=None, metavar="NAME", help="Select a kernel by name (optional name, default: only list names)") + g_common.add_argument("-s", "--source", type=str, default=None, metavar="NAME", help="Select a data source (default: list sources)") + g_common.add_argument("-i", "--item", type=str, default=None, metavar="NAME", help="Select an item within the source (default: list items)") g_common.add_argument("--no-color", action="store_true", help="Disable colored output") - g_profile = parser.add_argument_group("profile options") - g_profile.add_argument("--device", type=str, default=None, metavar="NAME", help="Select a device (optional name, default: only list names)") - g_rewrites = parser.add_argument_group("rewrites options") - g_rewrites.add_argument("--select", type=str, default=None, metavar="NAME", - help="Select an item within the chosen kernel (optional name, default: only list names)") parser.add_argument("--profile-path", type=pathlib.Path, metavar="PATH", help="Path to profile (optional file, default: latest profile)", default=pathlib.Path(temp("profile.pkl", append_user=True))) parser.add_argument("--rewrites-path", type=pathlib.Path, metavar="PATH", help="Path to rewrites (optional file, default: latest rewrites)", diff --git a/test/amd/test_sqttmap.py b/test/amd/test_sqttmap.py index b6b6de8c14..cbfa6f3a62 100644 --- a/test/amd/test_sqttmap.py +++ b/test/amd/test_sqttmap.py @@ -122,7 +122,7 @@ class TestSQTTMapBase(unittest.TestCase): out = run_cli("--profile", "--profile-path", str(pkl_path)) sqtt_traces = [l.strip() for l in out.split("\n") if "SQTT" in l] for name in sqtt_traces: - out = run_cli("--profile", "--profile-path", str(pkl_path), "--device", name) + out = run_cli("--profile", "--profile-path", str(pkl_path), "--source", name) lines = out.split("\n") self.assertIn("Clk", lines[0]) for r in lines[2:]: