From d7fb5d9b624aa9ce00883dbf6b1cf2f9d9984ce4 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sun, 14 Dec 2025 00:51:28 -0500
Subject: [PATCH] speedups: early return from simplify (#13665)

* early return from simplify
* pm_rewrite
* more speed
* remove again
* early return from simplify
* ugh
---
 CLAUDE.md           | 39 +++++++++++++++++++++++++++++++++++++++
 tinygrad/uop/ops.py | 14 +++++++-------
 2 files changed, 46 insertions(+), 7 deletions(-)

diff --git a/CLAUDE.md b/CLAUDE.md
index 7ef20d2291..858e5e9204 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -172,3 +172,42 @@ var, val = bind_uop.unbind()
 # Shapes can be symbolic (contain UOps)
 shape = tensor.shape # tuple[sint, ...] where sint = int | UOp
 ```
+
+## Performance Optimization
+
+When optimizing tinygrad internals:
+
+1. **Measure wall time, not just call counts** - Reducing `graph_rewrite` calls doesn't always improve wall time. The overhead of conditional checks can exceed the cost of the operation being skipped.
+
+2. **Profile each optimization individually** - Run benchmarks with and without each change to measure actual impact. Use `test/external/external_benchmark_schedule.py` for schedule/rewrite timing.
+
+3. **Early exits in hot paths are effective** - Simple checks like `if self.op is Ops.CONST: return self` in `simplify()` can eliminate many unnecessary `graph_rewrite` calls.
+
+4. **`graph_rewrite` is expensive** - Each call has overhead even for small graphs. Avoid calling it when the result is trivially known (e.g., simplifying a CONST returns itself).
+
+5. **Beware iterator overhead** - Checks like `all(x.op is Ops.CONST for x in self.src)` can be slower than just running the operation, especially for small sequences.
+
+6. **Verify cache hit rates before adding/keeping caches** - Measure actual hit rates with real workloads. A cache with 0% hit rate is pure overhead (e.g., `pm_cache` was removed because the algorithm guarantees each UOp is only passed to `pm_rewrite` once).
+
+7. **Use `TRACK_MATCH_STATS=2` to profile pattern matching** - This shows match rates and time per pattern. Look for patterns with 0% match rate that still cost significant time - these are pure overhead for that workload.
+
+8. **Cached properties beat manual traversal** - `backward_slice` uses `@functools.cached_property`. A DFS with early-exit sounds faster but is actually slower because it doesn't benefit from caching. The cache hit benefit often outweighs algorithmic improvements.
+
+9. **Avoid creating intermediate objects in hot paths** - For example, `any(x.op in ops for x in self.backward_slice)` is faster than `any(x.op in ops for x in {self:None, **self.backward_slice})` because it avoids dict creation.
+
+## Pattern Matching Profiling
+
+Use `TRACK_MATCH_STATS=2` to identify expensive patterns:
+
+```bash
+TRACK_MATCH_STATS=2 PYTHONPATH="." python3 test/external/external_benchmark_schedule.py
+```
+
+Output format: `matches / attempts -- match_time / total_time ms -- location`
+
+Key patterns to watch (from ResNet50 benchmark):
+- `split_load_store`: ~146ms, 31% match rate - does real work
+- `simplify_valid`: ~75ms, 0% match rate in this workload - checks AND ops for INDEX in backward slice
+- `vmin==vmax folding`: ~55ms, 0.33% match rate - checks 52K ops but rarely matches
+
+Patterns with 0% match rate are workload-specific overhead. They may be useful in other workloads, so don't remove them without understanding their purpose.
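The cache-hit-rate advice in item 6 above is easy to check empirically before adding or deleting a cache. The sketch below is illustrative only (it is not tinygrad code; `CountingDict` is a hypothetical helper): wrap the memo dict, run a representative workload, and read off the rate.

```python
# Illustrative sketch: instrument a memo cache to measure its hit rate.
# CountingDict is a hypothetical helper, not part of tinygrad.
class CountingDict(dict):
  """A dict that counts get() hits and misses so a useless cache is easy to spot."""
  def __init__(self):
    super().__init__()
    self.hits, self.misses = 0, 0
  def get(self, key, default=None):
    if key in self: self.hits += 1
    else: self.misses += 1
    return super().get(key, default)
  @property
  def hit_rate(self) -> float:
    total = self.hits + self.misses
    return self.hits / total if total else 0.0

# usage: temporarily swap the real cache dict for a CountingDict, run a real
# workload, and print the rate; ~0% means the cache is pure overhead
cache = CountingDict()
for key in [1, 2, 2, 3, 1]:
  if cache.get(key) is None: cache[key] = key * key  # stand-in for an expensive rewrite
print(f"hit rate: {cache.hit_rate:.0%}")
```

A near-0% rate on a realistic workload is exactly the signal behind dropping `pm_cache` in the ops.py change below.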
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 2171c87b4f..f241c81b13 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -158,7 +158,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   @property
   def backward_slice_with_self(self:UOp) -> dict[UOp, None]: return {self:None, **self.backward_slice}
-  def op_in_backward_slice_with_self(self, *ops:Ops): return any(x.op in ops for x in self.backward_slice_with_self)
+  def op_in_backward_slice_with_self(self, *ops:Ops) -> bool:
+    # Check self first, then iterate backward_slice (avoids creating intermediate dict)
+    return self.op in ops or any(x.op in ops for x in self.backward_slice)
 
   def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
     cache: dict[UOp, None] = {}
@@ -340,6 +342,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   # *** uop evaluation ***
 
   def simplify(self, tracked=False):
+    if self.op in {Ops.CONST, Ops.VCONST}: return self
     # late import!
     from tinygrad.uop.symbolic import symbolic
     with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
@@ -1191,16 +1194,13 @@ class BottomUpGate(Exception): pass
 class RewriteContext:
   def __init__(self, pm, bpm, ctx=None):
     self.pm: PatternMatcher|None = pm
-    self.pm_cache: dict[UOp, UOp|None] = {}
     self.bpm: PatternMatcher|None = bpm
     self.bpm_cache: dict[UOp, UOp|None] = {}
     self.ctx = ctx
     self.replace: dict[UOp, UOp] = {}
-  def cached_pm_rewrite(self, x:UOp) -> UOp|None:
-    if (ret:=self.pm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
-    ret = self.pm_cache[x] = unwrap(self.pm).rewrite(x, self.ctx)
-    return ret
+  # no cache needed: pm_rewrite is called at most once per UOp due to the replace dict check in unified_rewrite
+  def pm_rewrite(self, x:UOp) -> UOp|None: return unwrap(self.pm).rewrite(x, self.ctx)
 
   def cached_bpm_rewrite(self, x:UOp) -> UOp|None:
     if (ret:=self.bpm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
@@ -1247,7 +1247,7 @@ class RewriteContext:
       # in stage 1, once all srcs are rewritten, rebuild (if changed) or run top-down rewrite
       if (new_src:=tuple(tmp)) == new_n.src:
         # if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
-        if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
+        if self.pm is None or (new_src_n:=self.pm_rewrite(new_n)) is None:
           self.replace[n] = new_n
           continue
         else:
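The `simplify()` early return above is the "early exits in hot paths" guidance from CLAUDE.md in action. A rough way to see its effect is a micro-benchmark like the sketch below; it assumes a tinygrad checkout on `PYTHONPATH`, and the absolute numbers are machine- and workload-dependent.

```python
# Rough micro-benchmark sketch for the CONST early return in UOp.simplify().
# Assumes a tinygrad checkout on PYTHONPATH; timings are only indicative.
import timeit
from tinygrad import dtypes
from tinygrad.uop.ops import UOp

c = UOp.const(dtypes.int, 42)
# with the patch this call is a set-membership check plus an immediate return;
# without it, every call pays for a full graph_rewrite over the symbolic rules
print(timeit.timeit(c.simplify, number=100_000))
```

Either way, `c.simplify() is c` should hold for a bare CONST, since there is nothing to rewrite; the patch only avoids paying `graph_rewrite` overhead to discover that.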