From c7b18e6108f1f875b140d1009eccad196ddcf0cd Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 22 Mar 2026 17:17:05 +0200 Subject: [PATCH] viz: sqtt printer in viz/cli.py (#15411) * work * sqtt timeline in CLI * format all printers nicely * s/Showed/Printed * ansistrip * sys.exit * keep colors in list * work from amd_copy_matmul * has_more always gets returned * linter * don't print colors * more colors * wow this is so deep * work * minor details * selected * improve progress bar * remove it * 22, global_load_vaddr is so long --- .../tinybox_8xMI350X/profile.sh | 2 +- extra/viz/README | 3 +- extra/viz/cli.py | 67 ++++++++++++++++--- test/amd/test_sqttmap.py | 4 +- tinygrad/viz/serve.py | 11 +-- 5 files changed, 71 insertions(+), 16 deletions(-) diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/profile.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/profile.sh index fde461c018..0e2c61978c 100755 --- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/profile.sh +++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/profile.sh @@ -2,4 +2,4 @@ export BENCHMARK=5 export EVAL_BS=0 VIZ=${VIZ:--1} examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh -extra/viz/cli.py --profile --device "AMD" --top 20 +extra/viz/cli.py --profile --device "AMD" --limit 20 diff --git a/extra/viz/README b/extra/viz/README index a7db4f9f66..f0282500ae 100644 --- a/extra/viz/README +++ b/extra/viz/README @@ -1,6 +1,7 @@ A command line tool for exploring the VIZ trace. -After running with VIZ=-1, use `extra/viz/cli.py` to explore the saved trace files. +1. Set VIZ to -1 to save the trace. +2. Use `extra/viz/cli.py` to inspect the trace files. ## Inspect runtime profiling diff --git a/extra/viz/cli.py b/extra/viz/cli.py index f8246ba8f5..1cd2fffb50 100755 --- a/extra/viz/cli.py +++ b/extra/viz/cli.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -import argparse, pathlib, sys, struct, json +import argparse, pathlib, sys, struct, json, itertools from typing import Iterator from tinygrad.viz import serve as viz from tinygrad.uop.ops import RewriteTrace @@ -65,9 +65,12 @@ if __name__ == "__main__": g_mode.add_argument("--rewrites", action="store_true", help="View rewrites trace") g_common = parser.add_argument_group("common options") g_common.add_argument("--kernel", type=str, default=None, metavar="NAME", help="Select a kernel by name (optional name, default: only list names)") + g_common.add_argument("--no-color", action="store_true", default=not (sys.stdin.isatty() and sys.stdout.isatty()), + help="Disable colored output (default: true in non-interactive mode)") g_profile = parser.add_argument_group("profile options") g_profile.add_argument("--device", type=str, default=None, metavar="NAME", help="Select a device (optional name, default: only list names)") - g_profile.add_argument("--top", type=int, default=10, metavar="N", help="Number of top kernels to show (-1 for all, default: 10)") + g_profile.add_argument("--offset", type=int, default=0, metavar="N", help="event offset (default: 0)") + g_profile.add_argument("--limit", type=int, default=10, metavar="N", help="events to display (-1 for all, default: 10)") g_rewrites = parser.add_argument_group("rewrites options") g_rewrites.add_argument("--select", type=str, default=None, metavar="NAME", help="Select an item within the chosen kernel (optional name, default: only list names)") @@ -83,14 +86,61 @@ if __name__ == "__main__": viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})) viz.ctxs = viz.get_rewrites(viz.trace) + def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s + if args.profile: from tabulate import tabulate - profile = decode_profile(viz.get_profile(viz.load_pickle(args.profile_path, default=[]))) + profile = decode_profile(viz.get_profile(profile_data:=viz.load_pickle(args.profile_path, default=[]))) + viz.load_amd_counters(viz.ctxs, profile_data) + counters = {f'{c["name"]} SQTT {s["name"]}': s["data"] for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"] + if s["name"].startswith("PKTS")} + if args.device is None: + print("Select a device:") + for k in (*profile["layout"], *counters): + print(f" {format_colored(k)}") + sys.exit(0) + + # ** SQTT printer + if args.device is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.device), None)) is not None: + assert args.limit > 1, f"SQTT limit must be greater than 1, got {args.limit}" + sqtt_events, has_more = viz.sqtt_timeline(*sqtt_data, max_pkts=args.offset+args.limit) + sqtt_pkts = [e for e in sqtt_events if type(e).__name__ == "ProfileRangeEvent"] + pc_map = next(e.arg for e in sqtt_events if type(e).__name__ == "ProfilePointEvent" and e.key == 'pcMap') + # modern terminals support 24-bit color + def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m" + WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'), + (('STORE',), '#4fa3cc'), (('IMMEDIATE',), '#f3b44a'), (('BARRIER',), '#d00000'), (('LDS',), '#9fb4a6'), (('JUMP',), '#ffb703'), + (('JUMP_NO',), '#fb8500'), (('MESSAGE',), '#90dbf4'), (('WAVERDY',), '#1a2a2a')) + print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Info'}") + print("-" * 90) + # start from the first packet in trace, prepare packet indexes and map dispatches + pkt_idxs:dict[str, itertools.count] = {} + dispatch_to_pc:dict[str, int] = {} + for e in sqtt_pkts[:-args.limit]: + idx = next(pkt_idxs.setdefault(e.device, itertools.count())) + if e.name.ret is not None and e.name.ret.startswith("PC:"): dispatch_to_pc[f"{e.device}-{idx}"] = int(e.name.ret.replace("PC:", "")) + # start printing from the offset point + for e in sqtt_pkts[-args.limit:]: + op_name, info = e.name.display_name, e.name.ret or "" + color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None) + op_str = hex_colored(op_name, color) if color and not args.no_color else op_name + phase, pc = None, None + idx = next(pkt_idxs.setdefault(e.device, itertools.count())) + if info.startswith("PC:"): + dispatch_to_pc[f"{e.device}-{idx}"] = pc = int(info.replace("PC:", "")) + phase = "DISPATCH" + if info.startswith("LINK:"): phase, pc = "EXEC", dispatch_to_pc[info.replace("LINK:", "")] + if pc and phase: info = f"{phase:<8} 0x{pc:05x} {pc_map[pc]}" + print(f"{int(e.st):<12} {e.device:<20} {op_str}{' '*(22-ansilen(op_str))} {int(e.en-e.st):<4} {info}") + # note: we only print the important packets and skip the rest + if has_more: print(f"Selected packets {args.offset:,}-{args.offset + args.limit:,}. Use --offset and --limit to see others") + sys.exit(0) + + # ** Profiler printer agg, total, n = {}, 0, 0 - if args.device is None: print("Select a device:") for k,v in profile["layout"].items(): if not optional_eq({"name":k}, args.device): continue - print(f" {k}") + print(f" {format_colored(k)}") if args.device is None: continue for e in v.get("events", []): et = e["dur"]*1e-6 @@ -98,7 +148,7 @@ if __name__ == "__main__": if optional_eq(e, args.kernel) and n < 10: ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else "" name = e["name"]+(" " * (46 - ansilen(e["name"]))) - print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e['fmt'].replace('\n', ' | ')+" ") + print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e.get('fmt', '').replace('\n', ' | ')+" ") n += 1 else: a = agg.setdefault(e["name"], [0.0, 0]) @@ -107,14 +157,15 @@ if __name__ == "__main__": total += et if agg and total > 0: items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True) - sel = items if args.top == -1 else items[:args.top] + sel = items if args.limit == -1 else items[args.offset:args.offset+args.limit] table = [[name, time_to_str(t, w=9), c, f"{(t/total*100.0):.2f}%"] for name,(t,c) in sel] - if args.top != -1 and (other:=items[len(sel):]): + if args.limit != -1 and (other:=items[len(sel):]): other_t = total-sum(t for _, (t, _) in sel) table.append([f"Other ({len(other)} unique)", time_to_str(other_t, w=9), sum(c for _,(_,c) in other), f"{other_t/total*100.0:.2f}%"]) print(tabulate(table, headers=["name", "total", "count", "pct"], tablefmt="github")) sys.exit(0) + # ** Graph rewrites printer for k in viz.ctxs: if not optional_eq(k, args.kernel): continue print(k["name"]) diff --git a/test/amd/test_sqttmap.py b/test/amd/test_sqttmap.py index 40f6b87059..62cfacbda8 100644 --- a/test/amd/test_sqttmap.py +++ b/test/amd/test_sqttmap.py @@ -79,7 +79,7 @@ class TestSQTTMapBase(unittest.TestCase): if (p:=kern_events.get(event.kern)) is None: continue with self.subTest(example=name, kern=event.kern): # skip if there's no SQTT frequency data - if not (timeline:=sqtt_timeline(event.blob, p.lib, target)): continue + if not (timeline:=sqtt_timeline(event.blob, p.lib, target)[0]): continue if not (frequency:=[e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]): continue mean = sum(frequency) / len(frequency) variance = sum((v - mean) ** 2 for v in frequency) / len(frequency) @@ -101,7 +101,7 @@ class TestSQTTMapBase(unittest.TestCase): for name, (events, kern_events, target) in self.examples.items(): for event in events: wave_barriers = {} - for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target): + for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target)[0]: if type(e).__name__ == "ProfileRangeEvent" and e.name.display_name == "BARRIER": wave_barriers.setdefault(e.device, []).append(e) if not wave_barriers: continue for row, events in wave_barriers.items(): diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 9069993b90..0c3c044479 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -337,7 +337,7 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None: steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch))) ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps}) -def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: +def sqtt_timeline(data:bytes, lib:bytes, target:str, max_pkts=getenv("MAX_SQTT_PKTS",50_000)) -> tuple[list[ProfileEvent], bool]: from tinygrad.renderer.amd.sqtt import (map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC, INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA, WAVEEND, CDNA_WAVEEND, WAVERDY) @@ -365,8 +365,11 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None: exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}") if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): e.name = TracingKey(op or name, ret=f"LINK:{exec_pending[name].pop(0)}") + has_more = False for p, info in map_insts(data, lib, target): - if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break + if len(ret) > max_pkts: + has_more = True + break if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker: pair = (p._time, p.delta) if prev_pair is None: prev_pair = pair @@ -393,7 +396,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: else: add(name.replace("_ALT", ""), p, op=name) pc_map = {addr:str(inst) for addr,inst in amd_decode(lib, target).items()} - return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in row_ends]+ret + return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in row_ends]+ret, has_more # ** SQTT OCC only unpacks wave start, end time and SIMD location @@ -619,7 +622,7 @@ def get_render(query:str) -> dict: if fmt.startswith("prg-pkts"): ret = {} with soft_err(lambda err:ret.update(err)): - if (events:=get_profile(sqtt_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"} + if (events:=get_profile(sqtt_timeline(*data)[0], sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"} else: ret = {"src":"No SQTT trace on this SE."} return ret if fmt == "prg-sqtt":