mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz/cli: simplify to --source and --item flags (#15510)
* viz/cli: simplify to --source and --item flags * update viz cli test
This commit is contained in:
@@ -11,18 +11,18 @@ Flags: VIZ=-1 to only save the trace to a file, VIZ=1 also launches a web server
|
||||
|
||||
## Inspect runtime profiling
|
||||
|
||||
Use `extra/viz/cli.py --profile` to list all traced devices.
|
||||
Use `extra/viz/cli.py --profile` to list all sources.
|
||||
|
||||
List top slowest kernels on a device: `--profile --device "AMD" | head 10`
|
||||
List samples of a kernel on a device: `--profile --device "AMD" --kernel E_3 | head 4`
|
||||
List top slowest kernels on a source: `--profile -s "AMD" | head 10`
|
||||
List samples of a kernel on a source: `--profile -s "AMD" -i E_3 | head 4`
|
||||
|
||||
## Inspect codegen and PatternMatcher
|
||||
|
||||
Use `extra/viz/cli.py --rewrites` to list all traced kernels.
|
||||
Use `extra/viz/cli.py --rewrites` to list all sources.
|
||||
|
||||
List all codegen steps for a kernel: `--rewrites --kernel E_3`
|
||||
Get source code: `--rewrites --kernel E_3 --select "View Source"`
|
||||
Inspect a graph rewrite: `--rewrites --kernel E_3 --select "initial symbolic"`
|
||||
List all codegen steps for a kernel: `--rewrites -s E_3`
|
||||
Get source code: `--rewrites -s E_3 -i "View Source"`
|
||||
Inspect a graph rewrite: `--rewrites -s E_3 -i "initial symbolic"`
|
||||
|
||||
# SQTT tracing
|
||||
|
||||
@@ -31,12 +31,12 @@ Supported on AMD for RDNA3 and RDNA4 (best) and CDNA (developing).
|
||||
Flags: VIZ=-2 to save SQTT trace to a file. VIZ=2 also launches a web server. View other flags in tinygrad/runtime/ops_amd.py to configure SQTT as needed.
|
||||
|
||||
Use `extra/viz/cli.py --profile | grep SQTT` to view all available SQTT traces.
|
||||
You can select a specific trace option with --device, Example workflow:
|
||||
You can select a specific trace with --source, Example workflow:
|
||||
|
||||
```bash
|
||||
# Run amd_asm_matmul with VIZ=-2 to capture the trace
|
||||
VIZ=-2 python extra/gemm/amd_asm_matmul.py
|
||||
|
||||
# View barriers
|
||||
extra/viz/cli.py --profile --device "Exec kernel SQTT PKTS SE:0" | rg BARRIER | head -10
|
||||
extra/viz/cli.py --profile -s "Exec kernel SQTT PKTS SE:0" | rg BARRIER | head -10
|
||||
```
|
||||
|
||||
@@ -71,14 +71,14 @@ def main(args) -> None:
|
||||
viz.load_amd_counters(viz.ctxs, profile_data)
|
||||
counters = {f'{c["name"]} SQTT {s["name"]}': s["data"] for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"]
|
||||
if s["name"].startswith("PKTS")}
|
||||
if args.device is None:
|
||||
print("Select a device:")
|
||||
if args.source is None:
|
||||
print("Available sources:")
|
||||
for k in (*profile["layout"], *counters):
|
||||
print(f" {format_colored(k)}")
|
||||
return None
|
||||
|
||||
# ** SQTT printer
|
||||
if args.device is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.device), None)) is not None:
|
||||
if args.source is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.source), None)) is not None:
|
||||
# modern terminals support 24-bit color
|
||||
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
|
||||
WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'),
|
||||
@@ -108,13 +108,13 @@ def main(args) -> None:
|
||||
# ** Profiler printer
|
||||
agg, total, n = {}, 0, 0
|
||||
for k,v in profile["layout"].items():
|
||||
if not optional_eq({"name":k}, args.device): continue
|
||||
if not optional_eq({"name":k}, args.source): continue
|
||||
print(f" {format_colored(k)}")
|
||||
if args.device is None: continue
|
||||
if args.source is None: continue
|
||||
for e in v.get("events", []):
|
||||
et = e["dur"]*1e-6
|
||||
if args.kernel is not None:
|
||||
if optional_eq(e, args.kernel) and n < 10:
|
||||
if args.item is not None:
|
||||
if optional_eq(e, args.item) and n < 10:
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
|
||||
name = e["name"]+(" " * (46 - ansilen(e["name"])))
|
||||
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e.get('fmt', '').replace('\n', ' | ')+" ")
|
||||
@@ -132,13 +132,13 @@ def main(args) -> None:
|
||||
|
||||
# ** Graph rewrites printer
|
||||
for k in viz.ctxs:
|
||||
if not optional_eq(k, args.kernel): continue
|
||||
if not optional_eq(k, args.source): continue
|
||||
print(k["name"])
|
||||
if args.kernel is None: continue
|
||||
if args.source is None: continue
|
||||
for s in k["steps"]:
|
||||
if not optional_eq(s, args.select): continue
|
||||
if not optional_eq(s, args.item): continue
|
||||
print(" "*s["depth"]+s['name']+(f" - {s['match_count']}" if s.get('match_count') is not None else ''))
|
||||
if args.select is not None: print_data(viz.get_render(s['query']))
|
||||
if args.item is not None: print_data(viz.get_render(s['query']))
|
||||
|
||||
def get_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -146,13 +146,9 @@ def get_arg_parser() -> argparse.ArgumentParser:
|
||||
g_mode.add_argument("--profile", action="store_true", help="View profile trace")
|
||||
g_mode.add_argument("--rewrites", action="store_true", help="View rewrites trace")
|
||||
g_common = parser.add_argument_group("common options")
|
||||
g_common.add_argument("--kernel", type=str, default=None, metavar="NAME", help="Select a kernel by name (optional name, default: only list names)")
|
||||
g_common.add_argument("-s", "--source", type=str, default=None, metavar="NAME", help="Select a data source (default: list sources)")
|
||||
g_common.add_argument("-i", "--item", type=str, default=None, metavar="NAME", help="Select an item within the source (default: list items)")
|
||||
g_common.add_argument("--no-color", action="store_true", help="Disable colored output")
|
||||
g_profile = parser.add_argument_group("profile options")
|
||||
g_profile.add_argument("--device", type=str, default=None, metavar="NAME", help="Select a device (optional name, default: only list names)")
|
||||
g_rewrites = parser.add_argument_group("rewrites options")
|
||||
g_rewrites.add_argument("--select", type=str, default=None, metavar="NAME",
|
||||
help="Select an item within the chosen kernel (optional name, default: only list names)")
|
||||
parser.add_argument("--profile-path", type=pathlib.Path, metavar="PATH", help="Path to profile (optional file, default: latest profile)",
|
||||
default=pathlib.Path(temp("profile.pkl", append_user=True)))
|
||||
parser.add_argument("--rewrites-path", type=pathlib.Path, metavar="PATH", help="Path to rewrites (optional file, default: latest rewrites)",
|
||||
|
||||
@@ -122,7 +122,7 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
out = run_cli("--profile", "--profile-path", str(pkl_path))
|
||||
sqtt_traces = [l.strip() for l in out.split("\n") if "SQTT" in l]
|
||||
for name in sqtt_traces:
|
||||
out = run_cli("--profile", "--profile-path", str(pkl_path), "--device", name)
|
||||
out = run_cli("--profile", "--profile-path", str(pkl_path), "--source", name)
|
||||
lines = out.split("\n")
|
||||
self.assertIn("Clk", lines[0])
|
||||
for r in lines[2:]:
|
||||
|
||||
Reference in New Issue
Block a user