viz/cli: multi device profiler output, print markers (#15795)

* yield

* all devices

* better

* add unittests

* markers like this

* profile_markers work

* less

* update README

* tiny and null
This commit is contained in:
qazal
2026-04-17 23:40:10 +03:00
committed by GitHub
parent 0191cc73dc
commit 2581985532
3 changed files with 50 additions and 28 deletions

View File

@@ -14,11 +14,11 @@ Flags: VIZ=-1 to only save the trace to a file, VIZ=1 also launches a web server
Use `extra/viz/cli.py --profile` to list all sources.
```bash
# View top 40 slowest kernels and their AST (DEBUG=4 to see source code)
# View top 40 slowest kernels on the AMD device and their AST (DEBUG=4 to see source code)
DEBUG=3 extra/viz/cli.py --profile -s AMD --top 40
# Reconstruct DEBUG=3 output exactly as the runtime.
DEBUG=3 extra/viz/cli.py --profile -s AMD
# Reconstruct DEBUG=3 output exactly as the runtime. (all devices)
DEBUG=3 extra/viz/cli.py --profile -s ALL
```
## Inspect codegen and PatternMatcher

View File

@@ -1,5 +1,5 @@
#!/usr/bin/env python3
import argparse, pathlib, signal, sys, struct, json, itertools, os
import argparse, pathlib, signal, sys, struct, json, os, itertools, heapq
os.environ["VIZ"] = "0"
if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_DFL)
from typing import Iterator
@@ -84,10 +84,10 @@ def main(args) -> None:
profile = decode_profile(profile_bytes)
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
if args.src is None: return print("Select a source with -s"+"\n"+"\n".join([f" {fmt_colored(k)}" for k in profile["layout"]]))
if args.src is None: return print("Select a source with -s"+"\n ALL\n"+"\n".join([f" {fmt_colored(k)}" for k in profile["layout"]]))
# ** SQTT printer
data = get(profile["layout"], args.src)
data = None if args.src == "ALL" else get(profile["layout"], args.src)
if "SQTT" in args.src:
# modern terminals support 24-bit color
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
@@ -134,7 +134,7 @@ def main(args) -> None:
print(fmt(pmc_data[0])+"\n"+fmt(["-"*w for w in widths])+"\n"+("\n".join([fmt(row) for row in pmc_data[1:]])))
# ** Memory printer
elif data["event_type"] == 1:
elif data is not None and data["event_type"] == 1:
print(f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info")
for e in data["events"]:
info = str(e.get("arg", {}))
@@ -143,38 +143,48 @@ def main(args) -> None:
print(f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}")
# ** Profiler printer
elif data["event_type"] == 0:
kernels:list[dict] = []
if args.top:
agg:dict[str, tuple[float, int, int|None]] = {} # map kernel name to (total time, count and ref)
else:
timelines = [(n,l) for n,l in profile["layout"].items() if l.get("event_type") == 0]
def produce_top_kernels() -> Iterator[dict]:
tagged = ((n,e) for n,l in timelines for e in l["events"]) if args.src == "ALL" else ((args.src,e) for e in data["events"])
agg:dict[tuple[str,str], tuple[float, int, int|None]] = {} # map (device, kernel name) to (total time, count and ref)
total = 0
for e in data["events"]:
for dev,e in tagged:
et = e["dur"] * 1e-6
t, c, ref = agg.get(e["name"], (0.0, 0, None))
agg[e["name"]] = (t+et, c+1, e["ref"])
t, c, ref = agg.get((dev,e["name"]), (0.0, 0, None))
agg[(dev,e["name"])] = (t+et, c+1, e["ref"])
total += et
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
num_rows = len(items) if args.top < 0 else args.top
for name,(t,c,ref) in items[:num_rows]:
kernels.append({"name":name, "fmt":f"{time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%", "ref":ref})
for (dev,name),(t,c,ref) in items[:num_rows]:
display = f"{dev[:7]:7s} {name}" if args.src == "ALL" else name
yield {"name":display, "fmt":f"{time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%", "ref":ref}
if num_rows > 0 and items[num_rows:]:
other_t = sum(t for _,(t,_,_) in items[num_rows:])
other_c = sum(c for _,(_,c,_) in items[num_rows:])
kernels.append({"name":"Other", "fmt":f"{time_to_str(other_t, w=9)} {other_c:7d} {other_t/total*100.0:6.2f}%", "ref":None})
else:
st0 = data["events"][0]["st"] if data["events"] else 0
for k,e in enumerate(data["events"]):
et, timestamp = e["dur"] * 1e-6, (e["st"] - st0 + e["dur"]) * 1e-6
yield {"name":"Other", "fmt":f"{time_to_str(other_t, w=9)} {other_c:7d} {other_t/total*100.0:6.2f}%", "ref":None}
def produce_all_kernels() -> Iterator[dict]:
st0:int|None = None
event_streams = [[(e["st"], n, e) for e in l["events"]] for n,l in timelines] if args.src == "ALL" \
else [[(e["st"], args.src, e) for e in data["events"]]]
marker_stream = sorted([(m["ts"], "MARKER", m) for m in profile.get("markers", [])], key=lambda t:t[0])
for ts,dev,e in heapq.merge(*event_streams, marker_stream, key=lambda t:t[0]):
if st0 is None: st0 = ts
if dev == "MARKER":
yield {"name":f"--- MARKER {e['name']}", "fmt":f"@ {(ts-st0)*1e-3:9.2f}ms", "ref":None, "ext":None}
continue
et, timestamp, ext = e["dur"] * 1e-6, (e["st"] - st0 + e["dur"]) * 1e-6, None
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None)
if e["fmt"].startswith("TB:"): e["fmt"] = "" # TODO: print python backtrace at a reasonable DEBUG level
fmt_str = " ".join(p+" "*max(0, 14-ansilen(p)) for p in e["fmt"].split("\n"))
name = f"*** {args.src[:7]:7s} {k+1:4d} "+e["name"]+" "*(46-ansilen(e["name"]))
kernels.append({"name":name, "fmt":f"tm {ptm}/{timestamp*1e3:9.2f}ms"+(f" ({fmt_str})" if e["fmt"] else ""), "ref":e["ref"]})
for k in kernels:
name = f"*** {dev[:7]:7s} "+e["name"]+" "*(46-ansilen(e["name"]))
yield {"name":name, "fmt":f"tm {ptm}/{timestamp*1e3:9.2f}ms"+(f" ({fmt_str})" if e["fmt"] else ""), "ref":e["ref"], "ext":ext}
for k in (produce_top_kernels if args.top else produce_all_kernels)():
print(f"{fmt_colored(k['name'])}{' ' * max(0, 36 - ansilen(k['name']))} {k['fmt']}")
if k["ref"] is not None:
steps = rewrites[viz_data.ctxs[k["ref"]]["name"]]
if DEBUG >= 3 and (ast_step:=steps.get("View Base AST")) is not None: print_step(ast_step)
if DEBUG >= 4: print_step(steps["View Source"])
if DEBUG >= 4 and (src_step:=steps.get("View Source")) is not None: print_step(src_step)
def get_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(add_help=False)

View File

@@ -1,4 +1,4 @@
import unittest, decimal, sys, json, contextlib, tempfile, pickle, io
import unittest, decimal, sys, json, contextlib, tempfile, pickle, io, itertools
from pathlib import Path
from dataclasses import dataclass
from typing import Generator
@@ -896,21 +896,33 @@ def run_cli(*cli_args) -> str:
class TestCLI(unittest.TestCase):
def test_simple(self):
a = Tensor.empty(1, device="NULL")+2.0
empty_counter = itertools.count(0)
def custom_empty_prg(B:UOp, A:UOp) -> UOp:
sink = UOp(Ops.SINK, arg=KernelInfo(name="custom_empty"))
sink = UOp(Ops.SINK, arg=KernelInfo(name=f"custom_empty_n{next(empty_counter)}"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=a.device), UOp(Ops.LINEAR, src=(sink,))))
b = Tensor.custom_kernel(Tensor.empty_like(a), a, fxn=custom_empty_prg)[0]
c = Tensor.custom_kernel(Tensor.empty_like(a), a, fxn=custom_empty_prg)[0]
with save_viz() as viz:
b.realize()
profile_marker("marker @ 1")
c.realize()
# save trace to disk for CLI to consume it
with tempfile.TemporaryDirectory() as tmpdir:
(r:=Path(tmpdir)/"rewrites.pkl").write_bytes(pickle.dumps(viz.data.trace))
(p:=Path(tmpdir)/"profile.pkl").write_bytes(pickle.dumps(cpu_events))
# reconstruct DEBUG=4 output and see all markers.
with Context(DEBUG=4):
kernels = run_cli("--rewrites-path", str(r), "--profile-path", str(p), "-p", "-s", "NULL")
self.assertIn("void custom_empty", kernels)
self.assertIn("void custom_empty_n0", kernels)
self.assertIn("marker @ 1", kernels)
self.assertIn("void custom_empty_n1", kernels)
self.assertIn("E", kernels)
self.assertIn("UOp.const", kernels)
# get the top slowest functions across all devices
with Context(DEBUG=2):
times = run_cli("--rewrites-path", str(r), "--profile-path", str(p), "-p", "-s", "ALL", "--top", "-1")
self.assertIn("TINY", times)
self.assertIn("NULL", times)
if __name__ == "__main__":
unittest.main()