diff --git a/extra/viz/README.md b/extra/viz/README.md index 5bdb2f904c..875dfb34e5 100644 --- a/extra/viz/README.md +++ b/extra/viz/README.md @@ -14,11 +14,11 @@ Flags: VIZ=-1 to only save the trace to a file, VIZ=1 also launches a web server Use `extra/viz/cli.py --profile` to list all sources. ```bash -# View top 40 slowest kernels and their AST (DEBUG=4 to see source code) +# View top 40 slowest kernels on the AMD device and their AST (DEBUG=4 to see source code) DEBUG=3 extra/viz/cli.py --profile -s AMD --top 40 -# Reconstruct DEBUG=3 output exactly as the runtime. -DEBUG=3 extra/viz/cli.py --profile -s AMD +# Reconstruct DEBUG=3 output exactly as the runtime prints it, across all devices. +DEBUG=3 extra/viz/cli.py --profile -s ALL ``` ## Inspect codegen and PatternMatcher diff --git a/extra/viz/cli.py b/extra/viz/cli.py index c9a6fe04ef..d28356ccf6 100755 --- a/extra/viz/cli.py +++ b/extra/viz/cli.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -import argparse, pathlib, signal, sys, struct, json, itertools, os +import argparse, pathlib, signal, sys, struct, json, os, itertools, heapq os.environ["VIZ"] = "0" if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_DFL) from typing import Iterator @@ -84,10 +84,10 @@ def main(args) -> None: profile = decode_profile(profile_bytes) profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))]) - if args.src is None: return print("Select a source with -s"+"\n"+"\n".join([f" {fmt_colored(k)}" for k in profile["layout"]])) + if args.src is None: return print("Select a source with -s"+"\n ALL\n"+"\n".join([f" {fmt_colored(k)}" for k in profile["layout"]])) # ** SQTT printer - data = get(profile["layout"], args.src) + data = None if args.src == "ALL" else get(profile["layout"], args.src) if "SQTT" in args.src: # modern terminals support 24-bit color def 
hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m" @@ -134,7 +134,7 @@ def main(args) -> None: print(fmt(pmc_data[0])+"\n"+fmt(["-"*w for w in widths])+"\n"+("\n".join([fmt(row) for row in pmc_data[1:]]))) # ** Memory printer - elif data["event_type"] == 1: + elif data is not None and data["event_type"] == 1: print(f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info") for e in data["events"]: info = str(e.get("arg", {})) @@ -143,38 +143,48 @@ def main(args) -> None: print(f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}") # ** Profiler printer - elif data["event_type"] == 0: - kernels:list[dict] = [] - if args.top: - agg:dict[str, tuple[float, int, int|None]] = {} # map kernel name to (total time, count and ref) + else: + timelines = [(n,l) for n,l in profile["layout"].items() if l.get("event_type") == 0] + def produce_top_kernels() -> Iterator[dict]: + tagged = ((n,e) for n,l in timelines for e in l["events"]) if args.src == "ALL" else ((args.src,e) for e in data["events"]) + agg:dict[tuple[str,str], tuple[float, int, int|None]] = {} # map (device, kernel name) to (total time, count and ref) total = 0 - for e in data["events"]: + for dev,e in tagged: et = e["dur"] * 1e-6 - t, c, ref = agg.get(e["name"], (0.0, 0, None)) - agg[e["name"]] = (t+et, c+1, e["ref"]) + t, c, ref = agg.get((dev,e["name"]), (0.0, 0, None)) + agg[(dev,e["name"])] = (t+et, c+1, e["ref"]) total += et items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True) num_rows = len(items) if args.top < 0 else args.top - for name,(t,c,ref) in items[:num_rows]: - kernels.append({"name":name, "fmt":f"{time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%", "ref":ref}) + for (dev,name),(t,c,ref) in items[:num_rows]: + display = f"{dev[:7]:7s} {name}" if args.src == "ALL" else name + yield {"name":display, "fmt":f"{time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%", "ref":ref} if num_rows > 0 
and items[num_rows:]: other_t = sum(t for _,(t,_,_) in items[num_rows:]) other_c = sum(c for _,(_,c,_) in items[num_rows:]) - kernels.append({"name":"Other", "fmt":f"{time_to_str(other_t, w=9)} {other_c:7d} {other_t/total*100.0:6.2f}%", "ref":None}) - else: - st0 = data["events"][0]["st"] if data["events"] else 0 - for k,e in enumerate(data["events"]): - et, timestamp = e["dur"] * 1e-6, (e["st"] - st0 + e["dur"]) * 1e-6 + yield {"name":"Other", "fmt":f"{time_to_str(other_t, w=9)} {other_c:7d} {other_t/total*100.0:6.2f}%", "ref":None} + def produce_all_kernels() -> Iterator[dict]: + st0:int|None = None + event_streams = [[(e["st"], n, e) for e in l["events"]] for n,l in timelines] if args.src == "ALL" \ + else [[(e["st"], args.src, e) for e in data["events"]]] + marker_stream = sorted([(m["ts"], "MARKER", m) for m in profile.get("markers", [])], key=lambda t:t[0]) + for ts,dev,e in heapq.merge(*event_streams, marker_stream, key=lambda t:t[0]): + if st0 is None: st0 = ts + if dev == "MARKER": + yield {"name":f"--- MARKER {e['name']}", "fmt":f"@ {(ts-st0)*1e-3:9.2f}ms", "ref":None, "ext":None} + continue + et, timestamp, ext = e["dur"] * 1e-6, (e["st"] - st0 + e["dur"]) * 1e-6, None ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) + if e["fmt"].startswith("TB:"): e["fmt"] = "" # TODO: print python backtrace at a reasonable DEBUG level fmt_str = " ".join(p+" "*max(0, 14-ansilen(p)) for p in e["fmt"].split("\n")) - name = f"*** {args.src[:7]:7s} {k+1:4d} "+e["name"]+" "*(46-ansilen(e["name"])) - kernels.append({"name":name, "fmt":f"tm {ptm}/{timestamp*1e3:9.2f}ms"+(f" ({fmt_str})" if e["fmt"] else ""), "ref":e["ref"]}) - for k in kernels: + name = f"*** {dev[:7]:7s} "+e["name"]+" "*(46-ansilen(e["name"])) + yield {"name":name, "fmt":f"tm {ptm}/{timestamp*1e3:9.2f}ms"+(f" ({fmt_str})" if e["fmt"] else ""), "ref":e["ref"], "ext":ext} + for k in (produce_top_kernels if args.top else produce_all_kernels)(): print(f"{fmt_colored(k['name'])}{' ' * max(0, 
36 - ansilen(k['name']))} {k['fmt']}") if k["ref"] is not None: steps = rewrites[viz_data.ctxs[k["ref"]]["name"]] if DEBUG >= 3 and (ast_step:=steps.get("View Base AST")) is not None: print_step(ast_step) - if DEBUG >= 4: print_step(steps["View Source"]) + if DEBUG >= 4 and (src_step:=steps.get("View Source")) is not None: print_step(src_step) def get_arg_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(add_help=False) diff --git a/test/null/test_viz.py b/test/null/test_viz.py index 680fb4225a..f337593e96 100644 --- a/test/null/test_viz.py +++ b/test/null/test_viz.py @@ -1,4 +1,4 @@ -import unittest, decimal, sys, json, contextlib, tempfile, pickle, io +import unittest, decimal, sys, json, contextlib, tempfile, pickle, io, itertools from pathlib import Path from dataclasses import dataclass from typing import Generator @@ -896,21 +896,33 @@ def run_cli(*cli_args) -> str: class TestCLI(unittest.TestCase): def test_simple(self): a = Tensor.empty(1, device="NULL")+2.0 + empty_counter = itertools.count(0) def custom_empty_prg(B:UOp, A:UOp) -> UOp: - sink = UOp(Ops.SINK, arg=KernelInfo(name="custom_empty")) + sink = UOp(Ops.SINK, arg=KernelInfo(name=f"custom_empty_n{next(empty_counter)}")) return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=a.device), UOp(Ops.LINEAR, src=(sink,)))) b = Tensor.custom_kernel(Tensor.empty_like(a), a, fxn=custom_empty_prg)[0] + c = Tensor.custom_kernel(Tensor.empty_like(a), a, fxn=custom_empty_prg)[0] with save_viz() as viz: b.realize() + profile_marker("marker @ 1") + c.realize() # save trace to disk for CLI to consume it with tempfile.TemporaryDirectory() as tmpdir: (r:=Path(tmpdir)/"rewrites.pkl").write_bytes(pickle.dumps(viz.data.trace)) (p:=Path(tmpdir)/"profile.pkl").write_bytes(pickle.dumps(cpu_events)) + # reconstruct DEBUG=4 output and see all markers. 
with Context(DEBUG=4): kernels = run_cli("--rewrites-path", str(r), "--profile-path", str(p), "-p", "-s", "NULL") - self.assertIn("void custom_empty", kernels) + self.assertIn("void custom_empty_n0", kernels) + self.assertIn("marker @ 1", kernels) + self.assertIn("void custom_empty_n1", kernels) self.assertIn("E", kernels) self.assertIn("UOp.const", kernels) + # get the top slowest functions across all devices + with Context(DEBUG=2): + times = run_cli("--rewrites-path", str(r), "--profile-path", str(p), "-p", "-s", "ALL", "--top", "-1") + self.assertIn("TINY", times) + self.assertIn("NULL", times) if __name__ == "__main__": unittest.main()