mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz/cli cleanups (#15511)
* one less function * work * layout * better handling of rewrites * mypy passes
This commit is contained in:
113
extra/viz/cli.py
113
extra/viz/cli.py
@@ -4,27 +4,9 @@ if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_DFL)
|
||||
from typing import Iterator
|
||||
from tinygrad.viz import serve as viz
|
||||
from tinygrad.uop.ops import RewriteTrace
|
||||
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent
|
||||
|
||||
# ** generic helpers
|
||||
|
||||
def optional_eq(val:dict, arg:str|None) -> bool: return arg is None or ansistrip(val["name"]) == arg
|
||||
|
||||
def print_data(data:dict) -> None:
|
||||
if isinstance(data.get("value"), Iterator):
|
||||
for m in data["value"]:
|
||||
if m.get("uop"): print(f"Input UOp:\n{m['uop']}")
|
||||
if m.get("diff"):
|
||||
loc = pathlib.Path(m["upat"][0][0])
|
||||
print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}")
|
||||
for line in m["diff"]: print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
|
||||
if data.get("src") is not None: print(data["src"])
|
||||
|
||||
# ** Profiler trace decoder
|
||||
|
||||
# 0 means None, otherwise it's an enum value
|
||||
def option(i:int) -> int|None: return None if i == 0 else i-1
|
||||
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap
|
||||
|
||||
# profile decoder used in CLI and tests
|
||||
def decode_profile(data:bytes) -> dict:
|
||||
ret, off = data, 0
|
||||
def u(fmt:str) -> tuple:
|
||||
@@ -36,11 +18,14 @@ def decode_profile(data:bytes) -> dict:
|
||||
strings, dtypes, markers = json.loads(ret[off:off+index_len]).values()
|
||||
off += index_len
|
||||
layout:dict[str, dict] = {}
|
||||
# 0 means None, otherwise it's an enum value
|
||||
def option(i:int) -> int|None: return None if i == 0 else i-1
|
||||
for _ in range(layout_len):
|
||||
klen = u("<B")[0]
|
||||
k = ret[off:off+klen].decode()
|
||||
off += klen
|
||||
layout[k] = v = {"events":[]}
|
||||
v:dict = {"events":[]}
|
||||
layout[k] = v
|
||||
event_type, event_count = u("<BI")
|
||||
if event_type == 0:
|
||||
for _ in range(event_count):
|
||||
@@ -59,6 +44,11 @@ def decode_profile(data:bytes) -> dict:
|
||||
else: v["events"].append({"event":"free", "ts":ts, "key":key, "arg": {"users":[u("<IIIB") for _ in range(u("<I")[0])]}})
|
||||
return {"dur":total_dur, "peak":global_peak, "layout":layout, "markers":markers}
|
||||
|
||||
def get(data:dict, key:str):
|
||||
for k,v in data.items():
|
||||
if ansistrip(k) == key: return v
|
||||
raise RuntimeError(f'item "{key}" not found in list')
|
||||
|
||||
def main(args) -> None:
|
||||
viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))
|
||||
viz.ctxs = viz.get_rewrites(viz.trace)
|
||||
@@ -66,19 +56,20 @@ def main(args) -> None:
|
||||
def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s
|
||||
|
||||
if args.profile:
|
||||
from tabulate import tabulate
|
||||
profile = decode_profile(viz.get_profile(profile_data:=viz.load_pickle(args.profile_path, default=[])))
|
||||
viz.load_amd_counters(viz.ctxs, profile_data)
|
||||
counters = {f'{c["name"]} SQTT {s["name"]}': s["data"] for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"]
|
||||
if s["name"].startswith("PKTS")}
|
||||
events:list = viz.load_pickle(args.profile_path, default=[])
|
||||
if (profile_bytes:=viz.get_profile(events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}")
|
||||
profile = decode_profile(profile_bytes)
|
||||
viz.load_amd_counters(viz.ctxs, events)
|
||||
profile["layout"].update([(f'{c["name"]} SQTT {s["name"]}', s["data"]) for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"]
|
||||
if s["name"].startswith("PKTS")])
|
||||
if args.source is None:
|
||||
print("Available sources:")
|
||||
for k in (*profile["layout"], *counters):
|
||||
for k in profile["layout"]:
|
||||
print(f" {format_colored(k)}")
|
||||
return None
|
||||
|
||||
# ** SQTT printer
|
||||
if args.source is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.source), None)) is not None:
|
||||
data = get(profile["layout"], args.source)
|
||||
if "SQTT" in args.source:
|
||||
# modern terminals support 24-bit color
|
||||
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
|
||||
WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'),
|
||||
@@ -88,10 +79,11 @@ def main(args) -> None:
|
||||
print("-" * 90)
|
||||
pc_map:dict[int, str] = {}
|
||||
pkt_idxs:dict[str, itertools.count] = {}
|
||||
dispatch_to_inst:dict[str, int] = {}
|
||||
for e in viz.sqtt_timeline(*sqtt_data):
|
||||
dispatch_to_inst:dict[str, str] = {}
|
||||
for e in viz.sqtt_timeline(*data):
|
||||
if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg
|
||||
if not isinstance(e, ProfileRangeEvent): continue
|
||||
assert isinstance(e.name, TracingKey)
|
||||
op_name, info = e.name.display_name, e.name.ret or ""
|
||||
color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None)
|
||||
op_str = hex_colored(op_name, color) if color and not args.no_color else op_name
|
||||
@@ -102,43 +94,48 @@ def main(args) -> None:
|
||||
phase = "DISPATCH"
|
||||
if info.startswith("LINK:"): phase, inst = "EXEC", dispatch_to_inst[info.replace("LINK:", "")]
|
||||
if inst and phase: info = f"{phase:<8} {inst}"
|
||||
print(f"{int(e.st):<12} {e.device:<20} {op_str}{' '*(22-ansilen(op_str))} {int(e.en-e.st):<4} {info}")
|
||||
print(f"{int(e.st):<12} {e.device:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {info}")
|
||||
return None
|
||||
|
||||
# ** Profiler printer
|
||||
agg, total, n = {}, 0, 0
|
||||
for k,v in profile["layout"].items():
|
||||
if not optional_eq({"name":k}, args.source): continue
|
||||
print(f" {format_colored(k)}")
|
||||
if args.source is None: continue
|
||||
for e in v.get("events", []):
|
||||
et = e["dur"]*1e-6
|
||||
if args.item is not None:
|
||||
if optional_eq(e, args.item) and n < 10:
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
|
||||
name = e["name"]+(" " * (46 - ansilen(e["name"])))
|
||||
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e.get('fmt', '').replace('\n', ' | ')+" ")
|
||||
n += 1
|
||||
else:
|
||||
a = agg.setdefault(e["name"], [0.0, 0])
|
||||
a[0] += et
|
||||
a[1] += 1
|
||||
total += et
|
||||
agg:dict[str, tuple[float, int]] = {}
|
||||
total = 0
|
||||
for e in data.get("events", []):
|
||||
et = e["dur"] * 1e-6
|
||||
if args.item is not None:
|
||||
if ansistrip(e["name"]) == args.item:
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None)
|
||||
name = e["name"] + (" " * (46 - ansilen(e["name"])))
|
||||
print(f"{name} {ptm}/{et*1e3:9.2f}ms " + e.get("fmt", "").replace("\n", " | ") + " ")
|
||||
else:
|
||||
t, c = agg.get(e["name"], (0.0, 0))
|
||||
agg[e["name"]] = (t+et, c+1)
|
||||
total += et
|
||||
if agg and total > 0:
|
||||
from tabulate import tabulate
|
||||
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
|
||||
table = [[name, time_to_str(t, w=9), c, f"{(t/total*100.0):.2f}%"] for name,(t,c) in items]
|
||||
print(tabulate(table, headers=["name", "total", "count", "pct"], tablefmt="github"))
|
||||
return None
|
||||
|
||||
# ** Graph rewrites printer
|
||||
for k in viz.ctxs:
|
||||
if not optional_eq(k, args.source): continue
|
||||
print(k["name"])
|
||||
if args.source is None: continue
|
||||
for s in k["steps"]:
|
||||
if not optional_eq(s, args.item): continue
|
||||
print(" "*s["depth"]+s['name']+(f" - {s['match_count']}" if s.get('match_count') is not None else ''))
|
||||
if args.item is not None: print_data(viz.get_render(s['query']))
|
||||
rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz.ctxs if c.get("steps")}
|
||||
if args.source is None:
|
||||
for k in rewrites: print(f" {format_colored(k)}")
|
||||
return None
|
||||
steps = get(rewrites, args.source)
|
||||
if args.item is None:
|
||||
for k,v in steps.items(): print(" "*v["depth"]+k+(f" - {v['match_count']}" if v.get('match_count', 0) else ''))
|
||||
else:
|
||||
data = viz.get_render(get(steps, args.item)["query"])
|
||||
if isinstance(data.get("value"), Iterator):
|
||||
for m in data["value"]:
|
||||
if m.get("uop"): print(f"Input UOp:\n{m['uop']}")
|
||||
if m.get("diff"):
|
||||
loc = pathlib.Path(m["upat"][0][0])
|
||||
print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}")
|
||||
for line in m["diff"]: print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
|
||||
if data.get("src") is not None: print(data["src"])
|
||||
|
||||
def get_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import unittest, pickle, contextlib, io
|
||||
from typing import Iterator
|
||||
from pathlib import Path
|
||||
from tinygrad.helpers import DEBUG, getenv, temp
|
||||
from tinygrad.helpers import DEBUG, getenv, temp, ansistrip
|
||||
from tinygrad.renderer.amd.sqtt import print_packets, map_insts
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm
|
||||
from tinygrad.viz.serve import sqtt_timeline
|
||||
@@ -122,7 +122,7 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
out = run_cli("--profile", "--profile-path", str(pkl_path))
|
||||
sqtt_traces = [l.strip() for l in out.split("\n") if "SQTT" in l]
|
||||
for name in sqtt_traces:
|
||||
out = run_cli("--profile", "--profile-path", str(pkl_path), "--source", name)
|
||||
out = run_cli("--profile", "--profile-path", str(pkl_path), "--source", ansistrip(name))
|
||||
lines = out.split("\n")
|
||||
self.assertIn("Clk", lines[0])
|
||||
for r in lines[2:]:
|
||||
|
||||
Reference in New Issue
Block a user