diff --git a/CLAUDE.md b/CLAUDE.md index 1161db0235..25b7c5f204 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -192,9 +192,12 @@ When optimizing tinygrad internals: 9. **Avoid creating intermediate objects in hot paths** - For example, `any(x.op in ops for x in self.backward_slice)` is faster than `any(x.op in ops for x in {self:None, **self.backward_slice})` because it avoids dict creation. -## Pattern Matching Profiling +## Pattern Matching Analysis -Use `TRACK_MATCH_STATS=2` to identify expensive patterns: +**Use the right tool:** + +- `TRACK_MATCH_STATS=2` - **Profiling**: identify expensive patterns +- `VIZ=-1` - **Inspection**: see all transformations, what every match pattern does, the before/after diffs ```bash TRACK_MATCH_STATS=2 PYTHONPATH="." python3 test/external/external_benchmark_schedule.py @@ -209,6 +212,14 @@ Key patterns to watch (from ResNet50 benchmark): Patterns with 0% match rate are workload-specific overhead. They may be useful in other workloads, so don't remove them without understanding their purpose. +```bash +# Save the trace +VIZ=-1 python test/test_tiny.py TestTiny.test_gemm + +# Explore it +./extra/viz/cli.py --help +``` + ## AMD Performance Counter Profiling Set VIZ to `-2` to save performance counters traces for the AMD backend. 
#!/usr/bin/env python3
"""Terminal viewer for a saved rewrite trace (extra/viz/cli.py).

The trace is read from `temp("rewrites.pkl", append_user=True)`; an empty `RewriteTrace` is passed as the
loader's `default` (presumably used when no trace has been saved -- confirm against `viz.load_pickle`).

  no flags        print the name of every entry in the trace
  --kernel NAME   also list the steps of the entry whose name equals NAME
  --select NAME   within that entry, list only the step named NAME and print its rendered data
"""
import argparse, pathlib
from typing import Iterator
from tinygrad.viz import serve as viz
from tinygrad.uop.ops import RewriteTrace
from tinygrad.helpers import temp, ansistrip, colored

def optional_eq(val:dict, arg:str|None) -> bool:
  """Optional name filter: True when `arg` is None, otherwise whether `ansistrip(val["name"])` equals `arg`."""
  if arg is None: return True
  return ansistrip(val["name"]) == arg

def print_data(data:dict) -> None:
  """Print the rendered payload of one step.

  If `data["value"]` is an iterator, every item with a non-empty `"diff"` is printed as three parts:
  a `<parent dir>/<file>:<upat[0][1]>` location line built from `upat[0]`, then `upat[1]`, then the diff
  lines (red when a line starts with "-", green when it starts with "+", uncolored otherwise).
  Afterwards `data["src"]` is printed when it is present and not None.
  """
  matches = data.get("value")
  if isinstance(matches, Iterator):
    for match in matches:
      diff_lines = match["diff"]
      # matches that did not change anything are not shown
      if not diff_lines: continue
      location = match["upat"][0]
      src_file = pathlib.Path(location[0])
      print(f"{src_file.parent.name}/{src_file.name}:{location[1]}")
      print(match["upat"][1])
      for text in diff_lines:
        if text.startswith("-"): tint = "red"
        elif text.startswith("+"): tint = "green"
        else: tint = None
        print(colored(text, tint))
  source = data.get("src")
  if source is not None: print(source)

if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Select a kernel by name (optional name, default: only list names)')
  parser.add_argument('--select', type=str, default=None, metavar="NAME",
                      help='Select an item within the chosen kernel (optional name, default: only list names)')
  args = parser.parse_args()

  # the loaded trace and its contexts are stored on the viz module itself, as viz.trace / viz.ctxs
  trace_path = pathlib.Path(temp("rewrites.pkl", append_user=True))
  viz.trace = viz.load_pickle(trace_path, default=RewriteTrace([], [], {}))
  viz.ctxs = viz.get_rewrites(viz.trace)

  # both filters are lazy generators, so names and step data are printed in trace order as they are matched
  for ctx in (c for c in viz.ctxs if optional_eq(c, args.kernel)):
    print(ctx["name"])
    # without --kernel only the names are listed
    if args.kernel is None: continue
    for step in (s for s in ctx["steps"] if optional_eq(s, args.select)):
      label = " "*step["depth"]+step['name']
      count = step.get('match_count')
      if count is not None: label += f" - {count}"
      print(label)
      # step contents are rendered only when a step was explicitly selected
      if args.select is not None: print_data(viz.get_render(step['query']))