viz: sqtt printer in viz/cli.py (#15411)

* work

* sqtt timeline in CLI

* format all printers nicely

* s/Showed/Printed

* ansistrip

* sys.exit

* keep colors in list

* work from amd_copy_matmul

* has_more always gets returned

* linter

* don't print colors

* more colors

* wow this is so deep

* work

* minor details

* selected

* improve progress bar

* remove it

* 22, global_load_vaddr is so long
This commit is contained in:
qazal
2026-03-22 17:17:05 +02:00
committed by GitHub
parent bcc08307da
commit c7b18e6108
5 changed files with 71 additions and 16 deletions

View File

@@ -2,4 +2,4 @@
export BENCHMARK=5
export EVAL_BS=0
VIZ=${VIZ:--1} examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
extra/viz/cli.py --profile --device "AMD" --top 20
extra/viz/cli.py --profile --device "AMD" --limit 20

View File

@@ -1,6 +1,7 @@
A command line tool for exploring the VIZ trace.
After running with VIZ=-1, use `extra/viz/cli.py` to explore the saved trace files.
1. Set VIZ to -1 to save the trace.
2. Use `extra/viz/cli.py` to inspect the trace files.
## Inspect runtime profiling

View File

@@ -1,5 +1,5 @@
#!/usr/bin/env python3
import argparse, pathlib, sys, struct, json
import argparse, pathlib, sys, struct, json, itertools
from typing import Iterator
from tinygrad.viz import serve as viz
from tinygrad.uop.ops import RewriteTrace
@@ -65,9 +65,12 @@ if __name__ == "__main__":
g_mode.add_argument("--rewrites", action="store_true", help="View rewrites trace")
g_common = parser.add_argument_group("common options")
g_common.add_argument("--kernel", type=str, default=None, metavar="NAME", help="Select a kernel by name (optional name, default: only list names)")
g_common.add_argument("--no-color", action="store_true", default=not (sys.stdin.isatty() and sys.stdout.isatty()),
help="Disable colored output (default: true in non-interactive mode)")
g_profile = parser.add_argument_group("profile options")
g_profile.add_argument("--device", type=str, default=None, metavar="NAME", help="Select a device (optional name, default: only list names)")
g_profile.add_argument("--top", type=int, default=10, metavar="N", help="Number of top kernels to show (-1 for all, default: 10)")
g_profile.add_argument("--offset", type=int, default=0, metavar="N", help="event offset (default: 0)")
g_profile.add_argument("--limit", type=int, default=10, metavar="N", help="events to display (-1 for all, default: 10)")
g_rewrites = parser.add_argument_group("rewrites options")
g_rewrites.add_argument("--select", type=str, default=None, metavar="NAME",
help="Select an item within the chosen kernel (optional name, default: only list names)")
@@ -83,14 +86,61 @@ if __name__ == "__main__":
viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))
viz.ctxs = viz.get_rewrites(viz.trace)
def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s
if args.profile:
from tabulate import tabulate
profile = decode_profile(viz.get_profile(viz.load_pickle(args.profile_path, default=[])))
profile = decode_profile(viz.get_profile(profile_data:=viz.load_pickle(args.profile_path, default=[])))
viz.load_amd_counters(viz.ctxs, profile_data)
counters = {f'{c["name"]} SQTT {s["name"]}': s["data"] for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"]
if s["name"].startswith("PKTS")}
if args.device is None:
print("Select a device:")
for k in (*profile["layout"], *counters):
print(f" {format_colored(k)}")
sys.exit(0)
# ** SQTT printer
if args.device is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.device), None)) is not None:
assert args.limit > 1, f"SQTT limit must be greater than 1, got {args.limit}"
sqtt_events, has_more = viz.sqtt_timeline(*sqtt_data, max_pkts=args.offset+args.limit)
sqtt_pkts = [e for e in sqtt_events if type(e).__name__ == "ProfileRangeEvent"]
pc_map = next(e.arg for e in sqtt_events if type(e).__name__ == "ProfilePointEvent" and e.key == 'pcMap')
# modern terminals support 24-bit color
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'),
(('STORE',), '#4fa3cc'), (('IMMEDIATE',), '#f3b44a'), (('BARRIER',), '#d00000'), (('LDS',), '#9fb4a6'), (('JUMP',), '#ffb703'),
(('JUMP_NO',), '#fb8500'), (('MESSAGE',), '#90dbf4'), (('WAVERDY',), '#1a2a2a'))
print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Info'}")
print("-" * 90)
# start from the first packet in trace, prepare packet indexes and map dispatches
pkt_idxs:dict[str, itertools.count] = {}
dispatch_to_pc:dict[str, int] = {}
for e in sqtt_pkts[:-args.limit]:
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
if e.name.ret is not None and e.name.ret.startswith("PC:"): dispatch_to_pc[f"{e.device}-{idx}"] = int(e.name.ret.replace("PC:", ""))
# start printing from the offset point
for e in sqtt_pkts[-args.limit:]:
op_name, info = e.name.display_name, e.name.ret or ""
color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None)
op_str = hex_colored(op_name, color) if color and not args.no_color else op_name
phase, pc = None, None
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
if info.startswith("PC:"):
dispatch_to_pc[f"{e.device}-{idx}"] = pc = int(info.replace("PC:", ""))
phase = "DISPATCH"
if info.startswith("LINK:"): phase, pc = "EXEC", dispatch_to_pc[info.replace("LINK:", "")]
if pc and phase: info = f"{phase:<8} 0x{pc:05x} {pc_map[pc]}"
print(f"{int(e.st):<12} {e.device:<20} {op_str}{' '*(22-ansilen(op_str))} {int(e.en-e.st):<4} {info}")
# note: we only print the important packets and skip the rest
if has_more: print(f"Selected packets {args.offset:,}-{args.offset + args.limit:,}. Use --offset and --limit to see others")
sys.exit(0)
# ** Profiler printer
agg, total, n = {}, 0, 0
if args.device is None: print("Select a device:")
for k,v in profile["layout"].items():
if not optional_eq({"name":k}, args.device): continue
print(f" {k}")
print(f" {format_colored(k)}")
if args.device is None: continue
for e in v.get("events", []):
et = e["dur"]*1e-6
@@ -98,7 +148,7 @@ if __name__ == "__main__":
if optional_eq(e, args.kernel) and n < 10:
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
name = e["name"]+(" " * (46 - ansilen(e["name"])))
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e['fmt'].replace('\n', ' | ')+" ")
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e.get('fmt', '').replace('\n', ' | ')+" ")
n += 1
else:
a = agg.setdefault(e["name"], [0.0, 0])
@@ -107,14 +157,15 @@ if __name__ == "__main__":
total += et
if agg and total > 0:
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
sel = items if args.top == -1 else items[:args.top]
sel = items if args.limit == -1 else items[args.offset:args.offset+args.limit]
table = [[name, time_to_str(t, w=9), c, f"{(t/total*100.0):.2f}%"] for name,(t,c) in sel]
if args.top != -1 and (other:=items[len(sel):]):
if args.limit != -1 and (other:=items[len(sel):]):
other_t = total-sum(t for _, (t, _) in sel)
table.append([f"Other ({len(other)} unique)", time_to_str(other_t, w=9), sum(c for _,(_,c) in other), f"{other_t/total*100.0:.2f}%"])
print(tabulate(table, headers=["name", "total", "count", "pct"], tablefmt="github"))
sys.exit(0)
# ** Graph rewrites printer
for k in viz.ctxs:
if not optional_eq(k, args.kernel): continue
print(k["name"])

View File

@@ -79,7 +79,7 @@ class TestSQTTMapBase(unittest.TestCase):
if (p:=kern_events.get(event.kern)) is None: continue
with self.subTest(example=name, kern=event.kern):
# skip if there's no SQTT frequency data
if not (timeline:=sqtt_timeline(event.blob, p.lib, target)): continue
if not (timeline:=sqtt_timeline(event.blob, p.lib, target)[0]): continue
if not (frequency:=[e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]): continue
mean = sum(frequency) / len(frequency)
variance = sum((v - mean) ** 2 for v in frequency) / len(frequency)
@@ -101,7 +101,7 @@ class TestSQTTMapBase(unittest.TestCase):
for name, (events, kern_events, target) in self.examples.items():
for event in events:
wave_barriers = {}
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target):
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target)[0]:
if type(e).__name__ == "ProfileRangeEvent" and e.name.display_name == "BARRIER": wave_barriers.setdefault(e.device, []).append(e)
if not wave_barriers: continue
for row, events in wave_barriers.items():

View File

@@ -337,7 +337,7 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
def sqtt_timeline(data:bytes, lib:bytes, target:str, max_pkts=getenv("MAX_SQTT_PKTS",50_000)) -> tuple[list[ProfileEvent], bool]:
from tinygrad.renderer.amd.sqtt import (map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC,
ALUEXEC, INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA,
WAVEEND, CDNA_WAVEEND, WAVERDY)
@@ -365,8 +365,11 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None:
exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}")
if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): e.name = TracingKey(op or name, ret=f"LINK:{exec_pending[name].pop(0)}")
has_more = False
for p, info in map_insts(data, lib, target):
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
if len(ret) > max_pkts:
has_more = True
break
if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker:
pair = (p._time, p.delta)
if prev_pair is None: prev_pair = pair
@@ -393,7 +396,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
else:
add(name.replace("_ALT", ""), p, op=name)
pc_map = {addr:str(inst) for addr,inst in amd_decode(lib, target).items()}
return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in row_ends]+ret
return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in row_ends]+ret, has_more
# ** SQTT OCC only unpacks wave start, end time and SIMD location
@@ -619,7 +622,7 @@ def get_render(query:str) -> dict:
if fmt.startswith("prg-pkts"):
ret = {}
with soft_err(lambda err:ret.update(err)):
if (events:=get_profile(sqtt_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
if (events:=get_profile(sqtt_timeline(*data)[0], sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
else: ret = {"src":"No SQTT trace on this SE."}
return ret
if fmt == "prg-sqtt":