mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz/cli: --json support, refine docs (#15528)
* refine * remove * refine * keep * need to say this * back * feedback * feedback * json * dur_ms * et_ms * remove useless thing * docs * respect NO_COLOR * DEBUG also produces valid json
This commit is contained in:
@@ -4,21 +4,30 @@ A command line tool for exploring the VIZ trace.
|
||||
|
||||
Supported on all backends.
|
||||
|
||||
Flags: VIZ=-1 to only save the trace to a file, VIZ=1 also launches a web server.
|
||||
Flags: VIZ=-1 to only save the trace to a file.
|
||||
|
||||
1. Set VIZ to -1 to save the trace.
|
||||
2. Use `extra/viz/cli.py` to inspect the trace files. Set NO_COLOR=1 to disable colored output.
|
||||
By default, VIZ CLI automatically loads the latest trace files.
|
||||
|
||||
## Inspect runtime profiling
|
||||
|
||||
Use `extra/viz/cli.py --profile` to list all sources.
|
||||
Use `extra/viz/cli.py --profile -s ALL` to inspect the complete timing data of kernels, JIT, codegen and scheduling.
|
||||
|
||||
- Add DEBUG=3 to see the AST, DEBUG=4 to also see the source code.
|
||||
- Make sure to add NO_COLOR=1 to disable colored output.
|
||||
- Add --jsonl to emit the output as JSONL (one JSON object per line).
|
||||
|
||||
```bash
|
||||
# Extract the AST of all kernels
|
||||
DEBUG=3 extra/viz/cli.py --profile -s ALL > asts.txt
|
||||
|
||||
# Get kernel timing information in JSONL format
|
||||
extra/viz/cli.py --profile -s ALL --jsonl
|
||||
|
||||
# View top 40 slowest kernels on the AMD device and their AST (DEBUG=4 to see source code)
|
||||
DEBUG=3 extra/viz/cli.py --profile -s AMD --top 40
|
||||
|
||||
# Reconstruct the DEBUG=3 output exactly as printed at runtime (all devices)
|
||||
DEBUG=3 extra/viz/cli.py --profile -s ALL
|
||||
# List top 10 slowest operations across all devices
|
||||
extra/viz/cli.py --profile --top 10 -s ALL
|
||||
```
|
||||
|
||||
## Inspect codegen and PatternMatcher
|
||||
@@ -26,14 +35,13 @@ DEBUG=3 extra/viz/cli.py --profile -s ALL
|
||||
Use `extra/viz/cli.py --rewrites` to list all sources.
|
||||
|
||||
List all codegen steps for a kernel: `--rewrites -s E_3`
|
||||
Get source code: `--rewrites -s E_3 -i "View Source"`
|
||||
Inspect a graph rewrite: `--rewrites -s E_3 -i "initial symbolic"`
|
||||
|
||||
## SQTT tracing
|
||||
|
||||
Supported on AMD for RDNA3 and RDNA4 (best) and CDNA (developing).
|
||||
|
||||
Flags: VIZ=-2 to save SQTT trace to a file. VIZ=2 also launches a web server. View other flags in tinygrad/runtime/ops_amd.py to configure SQTT as needed.
|
||||
Flags: VIZ=-2 to save SQTT trace to a file. View other flags in tinygrad/runtime/ops_amd.py to configure SQTT as needed.
|
||||
|
||||
Use `extra/viz/cli.py --profile | grep SQTT` to view all available SQTT traces.
|
||||
You can select a specific trace with --source. Example workflow:
|
||||
|
||||
@@ -63,12 +63,12 @@ def main(args) -> None:
|
||||
data = viz.get_render(viz_data, step["query"])
|
||||
if isinstance(data.get("value"), Iterator):
|
||||
for m in data["value"]:
|
||||
if m.get("uop"): print(m["uop"])
|
||||
if m.get("uop"): print(json.dumps({"ast":m["uop"]}) if args.jsonl else m["uop"])
|
||||
if m.get("diff"):
|
||||
loc = pathlib.Path(m["upat"][0][0])
|
||||
print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}")
|
||||
for line in m["diff"]: print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
|
||||
if data.get("src") is not None: print(data["src"])
|
||||
if data.get("src") is not None: print(json.dumps({"src":data["src"]}) if args.jsonl else data["src"])
|
||||
|
||||
# ** Graph rewrites printer
|
||||
if args.rewrites:
|
||||
@@ -150,37 +150,42 @@ def main(args) -> None:
|
||||
agg:dict[tuple[str,str], tuple[float, int, int|None]] = {} # map (device, kernel name) to (total time, count and ref)
|
||||
total = 0
|
||||
for dev,e in tagged:
|
||||
et = e["dur"] * 1e-6
|
||||
et = e["dur"] * 1e-3
|
||||
t, c, ref = agg.get((dev,e["name"]), (0.0, 0, None))
|
||||
agg[(dev,e["name"])] = (t+et, c+1, e["ref"])
|
||||
total += et
|
||||
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
|
||||
num_rows = len(items) if args.top < 0 else args.top
|
||||
for (dev,name),(t,c,ref) in items[:num_rows]:
|
||||
display = f"{dev[:7]:7s} {name}" if args.src == "ALL" else name
|
||||
yield {"name":display, "fmt":f"{time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%", "ref":ref}
|
||||
display = f"{dev[:7]:7s} {fmt_colored(name)}" if args.src == "ALL" else name
|
||||
yield {"name":display, "dur_ms":t, "count":c, "pct":t/total*100.0, "ref":ref}
|
||||
if num_rows > 0 and items[num_rows:]:
|
||||
other_t = sum(t for _,(t,_,_) in items[num_rows:])
|
||||
other_c = sum(c for _,(_,c,_) in items[num_rows:])
|
||||
yield {"name":"Other", "fmt":f"{time_to_str(other_t, w=9)} {other_c:7d} {other_t/total*100.0:6.2f}%", "ref":None}
|
||||
yield {"name":"Other", "dur_ms":other_t, "count":other_c, "pct":other_t/total*100.0, "ref":None}
|
||||
def produce_all_kernels() -> Iterator[dict]:
|
||||
st0:int|None = None
|
||||
event_streams = [[(e["st"], n, e) for e in l["events"]] for n,l in timelines] if args.src == "ALL" \
|
||||
else [[(e["st"], args.src, e) for e in data["events"]]]
|
||||
marker_stream = sorted([(m["ts"], "MARKER", m) for m in profile.get("markers", [])], key=lambda t:t[0])
|
||||
for ts,dev,e in heapq.merge(*event_streams, marker_stream, key=lambda t:t[0]):
|
||||
if st0 is None: st0 = ts
|
||||
if dev == "MARKER":
|
||||
yield {"name":f"--- MARKER {e['name']}", "fmt":f"@ {(ts-st0)*1e-3:9.2f}ms", "ref":None, "ext":None}
|
||||
yield {"device":dev, "name":fmt_colored(e["name"]), "et_ms":ts*1e-3, "ref":None, "ext":None}
|
||||
continue
|
||||
et, timestamp, ext = e["dur"] * 1e-6, (e["st"] - st0 + e["dur"]) * 1e-6, None
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None)
|
||||
if e["fmt"].startswith("TB:"): e["fmt"] = "" # TODO: print python backtrace at a reasonable DEBUG level
|
||||
fmt_str = " ".join(p+" "*max(0, 14-ansilen(p)) for p in e["fmt"].split("\n"))
|
||||
name = f"*** {dev[:7]:7s} "+e["name"]+" "*(46-ansilen(e["name"]))
|
||||
yield {"name":name, "fmt":f"tm {ptm}/{timestamp*1e3:9.2f}ms"+(f" ({fmt_str})" if e["fmt"] else ""), "ref":e["ref"], "ext":ext}
|
||||
yield {"device":dev, "name":fmt_colored(e["name"]), "dur_ms":e["dur"]*1e-3,
|
||||
"et_ms":(e["st"]+e["dur"])*1e-3, "fmt":e["fmt"], "ref":e["ref"], "ext":None}
|
||||
def fmt_top(k:dict) -> str:
|
||||
return f"{fmt_colored(k['name'])}{' ' * max(0, 36-ansilen(k['name']))} {time_to_str(k['dur_ms']*1e-3, w=9)} {k['count']:7d} {k['pct']:6.2f}%"
|
||||
def fmt_all(k:dict) -> str:
|
||||
if k["device"] == "MARKER": return f"--- MARKER {k['name']} /{k['et_ms']:9.2f}ms"
|
||||
ptm = colored(time_to_str(k["dur_ms"]*1e-3, w=9), "yellow" if k["dur_ms"] > 10 else None)
|
||||
fmt_str = " ".join(p+" "*max(0, 14-ansilen(p)) for p in k["fmt"].split("\n"))
|
||||
name = f"*** {k['device'][:7]:7s} "+k["name"]+" "*(46-ansilen(k["name"]))
|
||||
return f"{name} tm {ptm}/{k['et_ms']:9.2f}ms"+(f" ({fmt_str})" if k["fmt"] else "")
|
||||
fmt_row = fmt_top if args.top else fmt_all
|
||||
for k in (produce_top_kernels if args.top else produce_all_kernels)():
|
||||
print(f"{fmt_colored(k['name'])}{' ' * max(0, 36 - ansilen(k['name']))} {k['fmt']}")
|
||||
if args.jsonl: print(json.dumps(k))
|
||||
else: print(fmt_row(k))
|
||||
if k["ref"] is not None:
|
||||
steps = rewrites[viz_data.ctxs[k["ref"]]["name"]]
|
||||
if DEBUG >= 3 and (ast_step:=steps.get("View Base AST")) is not None: print_step(ast_step)
|
||||
@@ -200,6 +205,7 @@ def get_arg_parser() -> argparse.ArgumentParser:
|
||||
default=pathlib.Path(temp("profile.pkl", append_user=True)))
|
||||
g_opts.add_argument("--rewrites-path", type=pathlib.Path, metavar="PATH", help="Path to rewrites.pkl (optional file, default: latest rewrites)",
|
||||
default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
|
||||
g_opts.add_argument("--jsonl", action="store_true", help="Emit profiler output as JSONL")
|
||||
g_opts.add_argument("-h", "--help", action="help", help="show this help message and exit")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -923,6 +923,9 @@ class TestCLI(unittest.TestCase):
|
||||
times = run_cli("--rewrites-path", str(r), "--profile-path", str(p), "-p", "-s", "ALL", "--top", "-1")
|
||||
self.assertIn("TINY", times)
|
||||
self.assertIn("NULL", times)
|
||||
with Context(DEBUG=3):
|
||||
json_lines = run_cli("--rewrites-path", str(r), "--profile-path", str(p), "-p", "-s", "ALL", "--jsonl")
|
||||
for line in json_lines.split("\n"): _ = json.loads(line)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user