From ccaa6bfc1987dfe0a3d1bc3ae6111d912d8fc5c1 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 28 Mar 2026 01:50:38 +0200 Subject: [PATCH] viz/cli cleanups (#15511) * one less function * work * layout * better handling of rewrites * mypy passes --- extra/viz/cli.py | 113 +++++++++++++++++++-------------------- test/amd/test_sqttmap.py | 4 +- 2 files changed, 57 insertions(+), 60 deletions(-) diff --git a/extra/viz/cli.py b/extra/viz/cli.py index a9af99553b..c8d1288d04 100755 --- a/extra/viz/cli.py +++ b/extra/viz/cli.py @@ -4,27 +4,9 @@ if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_DFL) from typing import Iterator from tinygrad.viz import serve as viz from tinygrad.uop.ops import RewriteTrace -from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent - -# ** generic helpers - -def optional_eq(val:dict, arg:str|None) -> bool: return arg is None or ansistrip(val["name"]) == arg - -def print_data(data:dict) -> None: - if isinstance(data.get("value"), Iterator): - for m in data["value"]: - if m.get("uop"): print(f"Input UOp:\n{m['uop']}") - if m.get("diff"): - loc = pathlib.Path(m["upat"][0][0]) - print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}") - for line in m["diff"]: print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None)) - if data.get("src") is not None: print(data["src"]) - -# ** Profiler trace decoder - -# 0 means None, otherwise it's an enum value -def option(i:int) -> int|None: return None if i == 0 else i-1 +from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap +# profile decoder used in CLI and tests def decode_profile(data:bytes) -> dict: ret, off = data, 0 def u(fmt:str) -> tuple: @@ -36,11 +18,14 @@ def decode_profile(data:bytes) -> dict: strings, dtypes, markers = json.loads(ret[off:off+index_len]).values() off += index_len layout:dict[str, dict] = {} + # 0 means None, otherwise it's an enum value + def option(i:int) -> int|None: return None if i == 0 else i-1 for _ in range(layout_len): klen = u(" dict: else: v["events"].append({"event":"free", "ts":ts, "key":key, "arg": {"users":[u(" None: viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})) viz.ctxs = viz.get_rewrites(viz.trace) @@ -66,19 +56,20 @@ def main(args) -> None: def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s if args.profile: - from tabulate import tabulate - profile = decode_profile(viz.get_profile(profile_data:=viz.load_pickle(args.profile_path, default=[]))) - viz.load_amd_counters(viz.ctxs, profile_data) - counters = {f'{c["name"]} SQTT {s["name"]}': s["data"] for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"] - if s["name"].startswith("PKTS")} + events:list = viz.load_pickle(args.profile_path, default=[]) + if (profile_bytes:=viz.get_profile(events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}") + profile = decode_profile(profile_bytes) + viz.load_amd_counters(viz.ctxs, events) + profile["layout"].update([(f'{c["name"]} SQTT {s["name"]}', s["data"]) for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"] + if s["name"].startswith("PKTS")]) if args.source is None: - print("Available sources:") - for k in (*profile["layout"], *counters): + for k in profile["layout"]: print(f" {format_colored(k)}") return None # ** SQTT printer - if args.source is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.source), None)) is not None: + data = get(profile["layout"], args.source) + if "SQTT" in args.source: # modern terminals support 24-bit color def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m" WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'), @@ -88,10 +79,11 @@ def main(args) -> None: print("-" * 90) pc_map:dict[int, str] = {} pkt_idxs:dict[str, itertools.count] = {} - dispatch_to_inst:dict[str, int] = {} - for e in viz.sqtt_timeline(*sqtt_data): + dispatch_to_inst:dict[str, str] = {} + for e in viz.sqtt_timeline(*data): if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg if not isinstance(e, ProfileRangeEvent): continue + assert isinstance(e.name, TracingKey) op_name, info = e.name.display_name, e.name.ret or "" color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None) op_str = hex_colored(op_name, color) if color and not args.no_color else op_name @@ -102,43 +94,48 @@ def main(args) -> None: phase = "DISPATCH" if info.startswith("LINK:"): phase, inst = "EXEC", dispatch_to_inst[info.replace("LINK:", "")] if inst and phase: info = f"{phase:<8} {inst}" - print(f"{int(e.st):<12} {e.device:<20} {op_str}{' '*(22-ansilen(op_str))} {int(e.en-e.st):<4} {info}") + print(f"{int(e.st):<12} {e.device:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {info}") return None # ** Profiler printer - agg, total, n = {}, 0, 0 - for k,v in profile["layout"].items(): - if not optional_eq({"name":k}, args.source): continue - print(f" {format_colored(k)}") - if args.source is None: continue - for e in v.get("events", []): - et = e["dur"]*1e-6 - if args.item is not None: - if optional_eq(e, args.item) and n < 10: - ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else "" - name = e["name"]+(" " * (46 - ansilen(e["name"]))) - print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e.get('fmt', '').replace('\n', ' | ')+" ") - n += 1 - else: - a = agg.setdefault(e["name"], [0.0, 0]) - a[0] += et - a[1] += 1 - total += et + agg:dict[str, tuple[float, int]] = {} + total = 0 + for e in data.get("events", []): + et = e["dur"] * 1e-6 + if args.item is not None: + if ansistrip(e["name"]) == args.item: + ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) + name = e["name"] + (" " * (46 - ansilen(e["name"]))) + print(f"{name} {ptm}/{et*1e3:9.2f}ms " + e.get("fmt", "").replace("\n", " | ") + " ") + else: + t, c = agg.get(e["name"], (0.0, 0)) + agg[e["name"]] = (t+et, c+1) + total += et if agg and total > 0: + from tabulate import tabulate items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True) table = [[name, time_to_str(t, w=9), c, f"{(t/total*100.0):.2f}%"] for name,(t,c) in items] print(tabulate(table, headers=["name", "total", "count", "pct"], tablefmt="github")) return None # ** Graph rewrites printer - for k in viz.ctxs: - if not optional_eq(k, args.source): continue - print(k["name"]) - if args.source is None: continue - for s in k["steps"]: - if not optional_eq(s, args.item): continue - print(" "*s["depth"]+s['name']+(f" - {s['match_count']}" if s.get('match_count') is not None else '')) - if args.item is not None: print_data(viz.get_render(s['query'])) + rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz.ctxs if c.get("steps")} + if args.source is None: + for k in rewrites: print(f" {format_colored(k)}") + return None + steps = get(rewrites, args.source) + if args.item is None: + for k,v in steps.items(): print(" "*v["depth"]+k+(f" - {v['match_count']}" if v.get('match_count', 0) else '')) + else: + data = viz.get_render(get(steps, args.item)["query"]) + if isinstance(data.get("value"), Iterator): + for m in data["value"]: + if m.get("uop"): print(f"Input UOp:\n{m['uop']}") + if m.get("diff"): + loc = pathlib.Path(m["upat"][0][0]) + print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}") + for line in m["diff"]: print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None)) + if data.get("src") is not None: print(data["src"]) def get_arg_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() diff --git a/test/amd/test_sqttmap.py b/test/amd/test_sqttmap.py index cbfa6f3a62..24dfb36c68 100644 --- a/test/amd/test_sqttmap.py +++ b/test/amd/test_sqttmap.py @@ -2,7 +2,7 @@ import unittest, pickle, contextlib, io from typing import Iterator from pathlib import Path -from tinygrad.helpers import DEBUG, getenv, temp +from tinygrad.helpers import DEBUG, getenv, temp, ansistrip from tinygrad.renderer.amd.sqtt import print_packets, map_insts from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm from tinygrad.viz.serve import sqtt_timeline @@ -122,7 +122,7 @@ class TestSQTTMapBase(unittest.TestCase): out = run_cli("--profile", "--profile-path", str(pkl_path)) sqtt_traces = [l.strip() for l in out.split("\n") if "SQTT" in l] for name in sqtt_traces: - out = run_cli("--profile", "--profile-path", str(pkl_path), "--source", name) + out = run_cli("--profile", "--profile-path", str(pkl_path), "--source", ansistrip(name)) lines = out.split("\n") self.assertIn("Clk", lines[0]) for r in lines[2:]: