mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz: sqtt printer in viz/cli.py (#15411)
* work * sqtt timeline in CLI * format all printers nicely * s/Showed/Printed * ansistrip * sys.exit * keep colors in list * work from amd_copy_matmul * has_more always gets returned * linter * don't print colors * more colors * wow this is so deep * work * minor details * selected * improve progress bar * remove it * 22, global_load_vaddr is so long
This commit is contained in:
@@ -2,4 +2,4 @@
|
||||
export BENCHMARK=5
|
||||
export EVAL_BS=0
|
||||
VIZ=${VIZ:--1} examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
|
||||
extra/viz/cli.py --profile --device "AMD" --top 20
|
||||
extra/viz/cli.py --profile --device "AMD" --limit 20
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
A command line tool for exploring the VIZ trace.
|
||||
|
||||
After running with VIZ=-1, use `extra/viz/cli.py` to explore the saved trace files.
|
||||
1. Set VIZ to -1 to save the trace.
|
||||
2. Use `extra/viz/cli.py` to inspect the trace files.
|
||||
|
||||
## Inspect runtime profiling
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse, pathlib, sys, struct, json
|
||||
import argparse, pathlib, sys, struct, json, itertools
|
||||
from typing import Iterator
|
||||
from tinygrad.viz import serve as viz
|
||||
from tinygrad.uop.ops import RewriteTrace
|
||||
@@ -65,9 +65,12 @@ if __name__ == "__main__":
|
||||
g_mode.add_argument("--rewrites", action="store_true", help="View rewrites trace")
|
||||
g_common = parser.add_argument_group("common options")
|
||||
g_common.add_argument("--kernel", type=str, default=None, metavar="NAME", help="Select a kernel by name (optional name, default: only list names)")
|
||||
g_common.add_argument("--no-color", action="store_true", default=not (sys.stdin.isatty() and sys.stdout.isatty()),
|
||||
help="Disable colored output (default: true in non-interactive mode)")
|
||||
g_profile = parser.add_argument_group("profile options")
|
||||
g_profile.add_argument("--device", type=str, default=None, metavar="NAME", help="Select a device (optional name, default: only list names)")
|
||||
g_profile.add_argument("--top", type=int, default=10, metavar="N", help="Number of top kernels to show (-1 for all, default: 10)")
|
||||
g_profile.add_argument("--offset", type=int, default=0, metavar="N", help="event offset (default: 0)")
|
||||
g_profile.add_argument("--limit", type=int, default=10, metavar="N", help="events to display (-1 for all, default: 10)")
|
||||
g_rewrites = parser.add_argument_group("rewrites options")
|
||||
g_rewrites.add_argument("--select", type=str, default=None, metavar="NAME",
|
||||
help="Select an item within the chosen kernel (optional name, default: only list names)")
|
||||
@@ -83,14 +86,61 @@ if __name__ == "__main__":
|
||||
viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))
|
||||
viz.ctxs = viz.get_rewrites(viz.trace)
|
||||
|
||||
def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s
|
||||
|
||||
if args.profile:
|
||||
from tabulate import tabulate
|
||||
profile = decode_profile(viz.get_profile(viz.load_pickle(args.profile_path, default=[])))
|
||||
profile = decode_profile(viz.get_profile(profile_data:=viz.load_pickle(args.profile_path, default=[])))
|
||||
viz.load_amd_counters(viz.ctxs, profile_data)
|
||||
counters = {f'{c["name"]} SQTT {s["name"]}': s["data"] for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"]
|
||||
if s["name"].startswith("PKTS")}
|
||||
if args.device is None:
|
||||
print("Select a device:")
|
||||
for k in (*profile["layout"], *counters):
|
||||
print(f" {format_colored(k)}")
|
||||
sys.exit(0)
|
||||
|
||||
# ** SQTT printer
|
||||
if args.device is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.device), None)) is not None:
|
||||
assert args.limit > 1, f"SQTT limit must be greater than 1, got {args.limit}"
|
||||
sqtt_events, has_more = viz.sqtt_timeline(*sqtt_data, max_pkts=args.offset+args.limit)
|
||||
sqtt_pkts = [e for e in sqtt_events if type(e).__name__ == "ProfileRangeEvent"]
|
||||
pc_map = next(e.arg for e in sqtt_events if type(e).__name__ == "ProfilePointEvent" and e.key == 'pcMap')
|
||||
# modern terminals support 24-bit color
|
||||
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
|
||||
WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'),
|
||||
(('STORE',), '#4fa3cc'), (('IMMEDIATE',), '#f3b44a'), (('BARRIER',), '#d00000'), (('LDS',), '#9fb4a6'), (('JUMP',), '#ffb703'),
|
||||
(('JUMP_NO',), '#fb8500'), (('MESSAGE',), '#90dbf4'), (('WAVERDY',), '#1a2a2a'))
|
||||
print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Info'}")
|
||||
print("-" * 90)
|
||||
# start from the first packet in trace, prepare packet indexes and map dispatches
|
||||
pkt_idxs:dict[str, itertools.count] = {}
|
||||
dispatch_to_pc:dict[str, int] = {}
|
||||
for e in sqtt_pkts[:-args.limit]:
|
||||
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
|
||||
if e.name.ret is not None and e.name.ret.startswith("PC:"): dispatch_to_pc[f"{e.device}-{idx}"] = int(e.name.ret.replace("PC:", ""))
|
||||
# start printing from the offset point
|
||||
for e in sqtt_pkts[-args.limit:]:
|
||||
op_name, info = e.name.display_name, e.name.ret or ""
|
||||
color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None)
|
||||
op_str = hex_colored(op_name, color) if color and not args.no_color else op_name
|
||||
phase, pc = None, None
|
||||
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
|
||||
if info.startswith("PC:"):
|
||||
dispatch_to_pc[f"{e.device}-{idx}"] = pc = int(info.replace("PC:", ""))
|
||||
phase = "DISPATCH"
|
||||
if info.startswith("LINK:"): phase, pc = "EXEC", dispatch_to_pc[info.replace("LINK:", "")]
|
||||
if pc and phase: info = f"{phase:<8} 0x{pc:05x} {pc_map[pc]}"
|
||||
print(f"{int(e.st):<12} {e.device:<20} {op_str}{' '*(22-ansilen(op_str))} {int(e.en-e.st):<4} {info}")
|
||||
# note: we only print the important packets and skip the rest
|
||||
if has_more: print(f"Selected packets {args.offset:,}-{args.offset + args.limit:,}. Use --offset and --limit to see others")
|
||||
sys.exit(0)
|
||||
|
||||
# ** Profiler printer
|
||||
agg, total, n = {}, 0, 0
|
||||
if args.device is None: print("Select a device:")
|
||||
for k,v in profile["layout"].items():
|
||||
if not optional_eq({"name":k}, args.device): continue
|
||||
print(f" {k}")
|
||||
print(f" {format_colored(k)}")
|
||||
if args.device is None: continue
|
||||
for e in v.get("events", []):
|
||||
et = e["dur"]*1e-6
|
||||
@@ -98,7 +148,7 @@ if __name__ == "__main__":
|
||||
if optional_eq(e, args.kernel) and n < 10:
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
|
||||
name = e["name"]+(" " * (46 - ansilen(e["name"])))
|
||||
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e['fmt'].replace('\n', ' | ')+" ")
|
||||
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e.get('fmt', '').replace('\n', ' | ')+" ")
|
||||
n += 1
|
||||
else:
|
||||
a = agg.setdefault(e["name"], [0.0, 0])
|
||||
@@ -107,14 +157,15 @@ if __name__ == "__main__":
|
||||
total += et
|
||||
if agg and total > 0:
|
||||
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
|
||||
sel = items if args.top == -1 else items[:args.top]
|
||||
sel = items if args.limit == -1 else items[args.offset:args.offset+args.limit]
|
||||
table = [[name, time_to_str(t, w=9), c, f"{(t/total*100.0):.2f}%"] for name,(t,c) in sel]
|
||||
if args.top != -1 and (other:=items[len(sel):]):
|
||||
if args.limit != -1 and (other:=items[len(sel):]):
|
||||
other_t = total-sum(t for _, (t, _) in sel)
|
||||
table.append([f"Other ({len(other)} unique)", time_to_str(other_t, w=9), sum(c for _,(_,c) in other), f"{other_t/total*100.0:.2f}%"])
|
||||
print(tabulate(table, headers=["name", "total", "count", "pct"], tablefmt="github"))
|
||||
sys.exit(0)
|
||||
|
||||
# ** Graph rewrites printer
|
||||
for k in viz.ctxs:
|
||||
if not optional_eq(k, args.kernel): continue
|
||||
print(k["name"])
|
||||
|
||||
@@ -79,7 +79,7 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
if (p:=kern_events.get(event.kern)) is None: continue
|
||||
with self.subTest(example=name, kern=event.kern):
|
||||
# skip if there's no SQTT frequency data
|
||||
if not (timeline:=sqtt_timeline(event.blob, p.lib, target)): continue
|
||||
if not (timeline:=sqtt_timeline(event.blob, p.lib, target)[0]): continue
|
||||
if not (frequency:=[e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]): continue
|
||||
mean = sum(frequency) / len(frequency)
|
||||
variance = sum((v - mean) ** 2 for v in frequency) / len(frequency)
|
||||
@@ -101,7 +101,7 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
for name, (events, kern_events, target) in self.examples.items():
|
||||
for event in events:
|
||||
wave_barriers = {}
|
||||
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target):
|
||||
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target)[0]:
|
||||
if type(e).__name__ == "ProfileRangeEvent" and e.name.display_name == "BARRIER": wave_barriers.setdefault(e.device, []).append(e)
|
||||
if not wave_barriers: continue
|
||||
for row, events in wave_barriers.items():
|
||||
|
||||
@@ -337,7 +337,7 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
|
||||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
|
||||
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
|
||||
def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
||||
def sqtt_timeline(data:bytes, lib:bytes, target:str, max_pkts=getenv("MAX_SQTT_PKTS",50_000)) -> tuple[list[ProfileEvent], bool]:
|
||||
from tinygrad.renderer.amd.sqtt import (map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC,
|
||||
ALUEXEC, INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA,
|
||||
WAVEEND, CDNA_WAVEEND, WAVERDY)
|
||||
@@ -365,8 +365,11 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
||||
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None:
|
||||
exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}")
|
||||
if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): e.name = TracingKey(op or name, ret=f"LINK:{exec_pending[name].pop(0)}")
|
||||
has_more = False
|
||||
for p, info in map_insts(data, lib, target):
|
||||
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
|
||||
if len(ret) > max_pkts:
|
||||
has_more = True
|
||||
break
|
||||
if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker:
|
||||
pair = (p._time, p.delta)
|
||||
if prev_pair is None: prev_pair = pair
|
||||
@@ -393,7 +396,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
||||
else:
|
||||
add(name.replace("_ALT", ""), p, op=name)
|
||||
pc_map = {addr:str(inst) for addr,inst in amd_decode(lib, target).items()}
|
||||
return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in row_ends]+ret
|
||||
return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in row_ends]+ret, has_more
|
||||
|
||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||
|
||||
@@ -619,7 +622,7 @@ def get_render(query:str) -> dict:
|
||||
if fmt.startswith("prg-pkts"):
|
||||
ret = {}
|
||||
with soft_err(lambda err:ret.update(err)):
|
||||
if (events:=get_profile(sqtt_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
|
||||
if (events:=get_profile(sqtt_timeline(*data)[0], sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
|
||||
else: ret = {"src":"No SQTT trace on this SE."}
|
||||
return ret
|
||||
if fmt == "prg-sqtt":
|
||||
|
||||
Reference in New Issue
Block a user