viz/cli: simplify to --source and --item flags (#15510)

* viz/cli: simplify to --source and --item flags

* update viz cli test
This commit is contained in:
qazal
2026-03-27 21:46:39 +02:00
committed by GitHub
parent 0d6fc0f571
commit dcc2a5d23b
3 changed files with 23 additions and 27 deletions

View File

@@ -11,18 +11,18 @@ Flags: VIZ=-1 to only save the trace to a file, VIZ=1 also launches a web server
## Inspect runtime profiling
Use `extra/viz/cli.py --profile` to list all traced devices.
Use `extra/viz/cli.py --profile` to list all sources.
List top slowest kernels on a device: `--profile --device "AMD" | head 10`
List samples of a kernel on a device: `--profile --device "AMD" --kernel E_3 | head 4`
List top slowest kernels on a source: `--profile -s "AMD" | head -n 10`
List samples of a kernel on a source: `--profile -s "AMD" -i E_3 | head -n 4`
## Inspect codegen and PatternMatcher
Use `extra/viz/cli.py --rewrites` to list all traced kernels.
Use `extra/viz/cli.py --rewrites` to list all sources.
List all codegen steps for a kernel: `--rewrites --kernel E_3`
Get source code: `--rewrites --kernel E_3 --select "View Source"`
Inspect a graph rewrite: `--rewrites --kernel E_3 --select "initial symbolic"`
List all codegen steps for a kernel: `--rewrites -s E_3`
Get source code: `--rewrites -s E_3 -i "View Source"`
Inspect a graph rewrite: `--rewrites -s E_3 -i "initial symbolic"`
# SQTT tracing
@@ -31,12 +31,12 @@ Supported on AMD for RDNA3 and RDNA4 (best) and CDNA (developing).
Flags: VIZ=-2 to save SQTT trace to a file. VIZ=2 also launches a web server. View other flags in tinygrad/runtime/ops_amd.py to configure SQTT as needed.
Use `extra/viz/cli.py --profile | grep SQTT` to view all available SQTT traces.
You can select a specific trace option with --device, Example workflow:
You can select a specific trace with --source. Example workflow:
```bash
# Run amd_asm_matmul with VIZ=-2 to capture the trace
VIZ=-2 python extra/gemm/amd_asm_matmul.py
# View barriers
extra/viz/cli.py --profile --device "Exec kernel SQTT PKTS SE:0" | rg BARRIER | head -10
extra/viz/cli.py --profile -s "Exec kernel SQTT PKTS SE:0" | rg BARRIER | head -10
```

View File

@@ -71,14 +71,14 @@ def main(args) -> None:
viz.load_amd_counters(viz.ctxs, profile_data)
counters = {f'{c["name"]} SQTT {s["name"]}': s["data"] for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"]
if s["name"].startswith("PKTS")}
if args.device is None:
print("Select a device:")
if args.source is None:
print("Available sources:")
for k in (*profile["layout"], *counters):
print(f" {format_colored(k)}")
return None
# ** SQTT printer
if args.device is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.device), None)) is not None:
if args.source is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.source), None)) is not None:
# modern terminals support 24-bit color
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'),
@@ -108,13 +108,13 @@ def main(args) -> None:
# ** Profiler printer
agg, total, n = {}, 0, 0
for k,v in profile["layout"].items():
if not optional_eq({"name":k}, args.device): continue
if not optional_eq({"name":k}, args.source): continue
print(f" {format_colored(k)}")
if args.device is None: continue
if args.source is None: continue
for e in v.get("events", []):
et = e["dur"]*1e-6
if args.kernel is not None:
if optional_eq(e, args.kernel) and n < 10:
if args.item is not None:
if optional_eq(e, args.item) and n < 10:
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
name = e["name"]+(" " * (46 - ansilen(e["name"])))
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e.get('fmt', '').replace('\n', ' | ')+" ")
@@ -132,13 +132,13 @@ def main(args) -> None:
# ** Graph rewrites printer
for k in viz.ctxs:
if not optional_eq(k, args.kernel): continue
if not optional_eq(k, args.source): continue
print(k["name"])
if args.kernel is None: continue
if args.source is None: continue
for s in k["steps"]:
if not optional_eq(s, args.select): continue
if not optional_eq(s, args.item): continue
print(" "*s["depth"]+s['name']+(f" - {s['match_count']}" if s.get('match_count') is not None else ''))
if args.select is not None: print_data(viz.get_render(s['query']))
if args.item is not None: print_data(viz.get_render(s['query']))
def get_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
@@ -146,13 +146,9 @@ def get_arg_parser() -> argparse.ArgumentParser:
g_mode.add_argument("--profile", action="store_true", help="View profile trace")
g_mode.add_argument("--rewrites", action="store_true", help="View rewrites trace")
g_common = parser.add_argument_group("common options")
g_common.add_argument("--kernel", type=str, default=None, metavar="NAME", help="Select a kernel by name (optional name, default: only list names)")
g_common.add_argument("-s", "--source", type=str, default=None, metavar="NAME", help="Select a data source (default: list sources)")
g_common.add_argument("-i", "--item", type=str, default=None, metavar="NAME", help="Select an item within the source (default: list items)")
g_common.add_argument("--no-color", action="store_true", help="Disable colored output")
g_profile = parser.add_argument_group("profile options")
g_profile.add_argument("--device", type=str, default=None, metavar="NAME", help="Select a device (optional name, default: only list names)")
g_rewrites = parser.add_argument_group("rewrites options")
g_rewrites.add_argument("--select", type=str, default=None, metavar="NAME",
help="Select an item within the chosen kernel (optional name, default: only list names)")
parser.add_argument("--profile-path", type=pathlib.Path, metavar="PATH", help="Path to profile (optional file, default: latest profile)",
default=pathlib.Path(temp("profile.pkl", append_user=True)))
parser.add_argument("--rewrites-path", type=pathlib.Path, metavar="PATH", help="Path to rewrites (optional file, default: latest rewrites)",

View File

@@ -122,7 +122,7 @@ class TestSQTTMapBase(unittest.TestCase):
out = run_cli("--profile", "--profile-path", str(pkl_path))
sqtt_traces = [l.strip() for l in out.split("\n") if "SQTT" in l]
for name in sqtt_traces:
out = run_cli("--profile", "--profile-path", str(pkl_path), "--device", name)
out = run_cli("--profile", "--profile-path", str(pkl_path), "--source", name)
lines = out.split("\n")
self.assertIn("Clk", lines[0])
for r in lines[2:]: