mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz/cli: multi device profiler output, print markers (#15795)
* yield * all devices * better * add unittests * markers like this * profile_markers work * less * update README * tiny and null
This commit is contained in:
@@ -14,11 +14,11 @@ Flags: VIZ=-1 to only save the trace to a file, VIZ=1 also launches a web server
|
||||
Use `extra/viz/cli.py --profile` to list all sources.
|
||||
|
||||
```bash
|
||||
# View top 40 slowest kernels and their AST (DEBUG=4 to see source code)
|
||||
# View top 40 slowest kernels on the AMD device and their AST (DEBUG=4 to see source code)
|
||||
DEBUG=3 extra/viz/cli.py --profile -s AMD --top 40
|
||||
|
||||
# Reconstruct the DEBUG=3 output exactly as the runtime prints it.
|
||||
DEBUG=3 extra/viz/cli.py --profile -s AMD
|
||||
# Reconstruct the DEBUG=3 output exactly as the runtime prints it (all devices).
|
||||
DEBUG=3 extra/viz/cli.py --profile -s ALL
|
||||
```
|
||||
|
||||
## Inspect codegen and PatternMatcher
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse, pathlib, signal, sys, struct, json, itertools, os
|
||||
import argparse, pathlib, signal, sys, struct, json, os, itertools, heapq
|
||||
os.environ["VIZ"] = "0"
|
||||
if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_DFL)
|
||||
from typing import Iterator
|
||||
@@ -84,10 +84,10 @@ def main(args) -> None:
|
||||
profile = decode_profile(profile_bytes)
|
||||
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs
|
||||
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
|
||||
if args.src is None: return print("Select a source with -s"+"\n"+"\n".join([f" {fmt_colored(k)}" for k in profile["layout"]]))
|
||||
if args.src is None: return print("Select a source with -s"+"\n ALL\n"+"\n".join([f" {fmt_colored(k)}" for k in profile["layout"]]))
|
||||
|
||||
# ** SQTT printer
|
||||
data = get(profile["layout"], args.src)
|
||||
data = None if args.src == "ALL" else get(profile["layout"], args.src)
|
||||
if "SQTT" in args.src:
|
||||
# modern terminals support 24-bit color
|
||||
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
|
||||
@@ -134,7 +134,7 @@ def main(args) -> None:
|
||||
print(fmt(pmc_data[0])+"\n"+fmt(["-"*w for w in widths])+"\n"+("\n".join([fmt(row) for row in pmc_data[1:]])))
|
||||
|
||||
# ** Memory printer
|
||||
elif data["event_type"] == 1:
|
||||
elif data is not None and data["event_type"] == 1:
|
||||
print(f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info")
|
||||
for e in data["events"]:
|
||||
info = str(e.get("arg", {}))
|
||||
@@ -143,38 +143,48 @@ def main(args) -> None:
|
||||
print(f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}")
|
||||
|
||||
# ** Profiler printer
|
||||
elif data["event_type"] == 0:
|
||||
kernels:list[dict] = []
|
||||
if args.top:
|
||||
agg:dict[str, tuple[float, int, int|None]] = {} # map kernel name to (total time, count and ref)
|
||||
else:
|
||||
timelines = [(n,l) for n,l in profile["layout"].items() if l.get("event_type") == 0]
|
||||
def produce_top_kernels() -> Iterator[dict]:
|
||||
tagged = ((n,e) for n,l in timelines for e in l["events"]) if args.src == "ALL" else ((args.src,e) for e in data["events"])
|
||||
agg:dict[tuple[str,str], tuple[float, int, int|None]] = {} # map (device, kernel name) to (total time, count and ref)
|
||||
total = 0
|
||||
for e in data["events"]:
|
||||
for dev,e in tagged:
|
||||
et = e["dur"] * 1e-6
|
||||
t, c, ref = agg.get(e["name"], (0.0, 0, None))
|
||||
agg[e["name"]] = (t+et, c+1, e["ref"])
|
||||
t, c, ref = agg.get((dev,e["name"]), (0.0, 0, None))
|
||||
agg[(dev,e["name"])] = (t+et, c+1, e["ref"])
|
||||
total += et
|
||||
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
|
||||
num_rows = len(items) if args.top < 0 else args.top
|
||||
for name,(t,c,ref) in items[:num_rows]:
|
||||
kernels.append({"name":name, "fmt":f"{time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%", "ref":ref})
|
||||
for (dev,name),(t,c,ref) in items[:num_rows]:
|
||||
display = f"{dev[:7]:7s} {name}" if args.src == "ALL" else name
|
||||
yield {"name":display, "fmt":f"{time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%", "ref":ref}
|
||||
if num_rows > 0 and items[num_rows:]:
|
||||
other_t = sum(t for _,(t,_,_) in items[num_rows:])
|
||||
other_c = sum(c for _,(_,c,_) in items[num_rows:])
|
||||
kernels.append({"name":"Other", "fmt":f"{time_to_str(other_t, w=9)} {other_c:7d} {other_t/total*100.0:6.2f}%", "ref":None})
|
||||
else:
|
||||
st0 = data["events"][0]["st"] if data["events"] else 0
|
||||
for k,e in enumerate(data["events"]):
|
||||
et, timestamp = e["dur"] * 1e-6, (e["st"] - st0 + e["dur"]) * 1e-6
|
||||
yield {"name":"Other", "fmt":f"{time_to_str(other_t, w=9)} {other_c:7d} {other_t/total*100.0:6.2f}%", "ref":None}
|
||||
def produce_all_kernels() -> Iterator[dict]:
|
||||
st0:int|None = None
|
||||
event_streams = [[(e["st"], n, e) for e in l["events"]] for n,l in timelines] if args.src == "ALL" \
|
||||
else [[(e["st"], args.src, e) for e in data["events"]]]
|
||||
marker_stream = sorted([(m["ts"], "MARKER", m) for m in profile.get("markers", [])], key=lambda t:t[0])
|
||||
for ts,dev,e in heapq.merge(*event_streams, marker_stream, key=lambda t:t[0]):
|
||||
if st0 is None: st0 = ts
|
||||
if dev == "MARKER":
|
||||
yield {"name":f"--- MARKER {e['name']}", "fmt":f"@ {(ts-st0)*1e-3:9.2f}ms", "ref":None, "ext":None}
|
||||
continue
|
||||
et, timestamp, ext = e["dur"] * 1e-6, (e["st"] - st0 + e["dur"]) * 1e-6, None
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None)
|
||||
if e["fmt"].startswith("TB:"): e["fmt"] = "" # TODO: print python backtrace at a reasonable DEBUG level
|
||||
fmt_str = " ".join(p+" "*max(0, 14-ansilen(p)) for p in e["fmt"].split("\n"))
|
||||
name = f"*** {args.src[:7]:7s} {k+1:4d} "+e["name"]+" "*(46-ansilen(e["name"]))
|
||||
kernels.append({"name":name, "fmt":f"tm {ptm}/{timestamp*1e3:9.2f}ms"+(f" ({fmt_str})" if e["fmt"] else ""), "ref":e["ref"]})
|
||||
for k in kernels:
|
||||
name = f"*** {dev[:7]:7s} "+e["name"]+" "*(46-ansilen(e["name"]))
|
||||
yield {"name":name, "fmt":f"tm {ptm}/{timestamp*1e3:9.2f}ms"+(f" ({fmt_str})" if e["fmt"] else ""), "ref":e["ref"], "ext":ext}
|
||||
for k in (produce_top_kernels if args.top else produce_all_kernels)():
|
||||
print(f"{fmt_colored(k['name'])}{' ' * max(0, 36 - ansilen(k['name']))} {k['fmt']}")
|
||||
if k["ref"] is not None:
|
||||
steps = rewrites[viz_data.ctxs[k["ref"]]["name"]]
|
||||
if DEBUG >= 3 and (ast_step:=steps.get("View Base AST")) is not None: print_step(ast_step)
|
||||
if DEBUG >= 4: print_step(steps["View Source"])
|
||||
if DEBUG >= 4 and (src_step:=steps.get("View Source")) is not None: print_step(src_step)
|
||||
|
||||
def get_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import unittest, decimal, sys, json, contextlib, tempfile, pickle, io
|
||||
import unittest, decimal, sys, json, contextlib, tempfile, pickle, io, itertools
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import Generator
|
||||
@@ -896,21 +896,33 @@ def run_cli(*cli_args) -> str:
|
||||
class TestCLI(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
a = Tensor.empty(1, device="NULL")+2.0
|
||||
empty_counter = itertools.count(0)
|
||||
def custom_empty_prg(B:UOp, A:UOp) -> UOp:
|
||||
sink = UOp(Ops.SINK, arg=KernelInfo(name="custom_empty"))
|
||||
sink = UOp(Ops.SINK, arg=KernelInfo(name=f"custom_empty_n{next(empty_counter)}"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=a.device), UOp(Ops.LINEAR, src=(sink,))))
|
||||
b = Tensor.custom_kernel(Tensor.empty_like(a), a, fxn=custom_empty_prg)[0]
|
||||
c = Tensor.custom_kernel(Tensor.empty_like(a), a, fxn=custom_empty_prg)[0]
|
||||
with save_viz() as viz:
|
||||
b.realize()
|
||||
profile_marker("marker @ 1")
|
||||
c.realize()
|
||||
# save trace to disk for CLI to consume it
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
(r:=Path(tmpdir)/"rewrites.pkl").write_bytes(pickle.dumps(viz.data.trace))
|
||||
(p:=Path(tmpdir)/"profile.pkl").write_bytes(pickle.dumps(cpu_events))
|
||||
# reconstruct DEBUG=4 output and see all markers.
|
||||
with Context(DEBUG=4):
|
||||
kernels = run_cli("--rewrites-path", str(r), "--profile-path", str(p), "-p", "-s", "NULL")
|
||||
self.assertIn("void custom_empty", kernels)
|
||||
self.assertIn("void custom_empty_n0", kernels)
|
||||
self.assertIn("marker @ 1", kernels)
|
||||
self.assertIn("void custom_empty_n1", kernels)
|
||||
self.assertIn("E", kernels)
|
||||
self.assertIn("UOp.const", kernels)
|
||||
# get the top slowest functions across all devices
|
||||
with Context(DEBUG=2):
|
||||
times = run_cli("--rewrites-path", str(r), "--profile-path", str(p), "-p", "-s", "ALL", "--top", "-1")
|
||||
self.assertIn("TINY", times)
|
||||
self.assertIn("NULL", times)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user