speedups: early return from simplify (#13665)

* early return from simplify

* pm_rewrite

* more speed

* remove again

* early return from simplify

* ugh
This commit is contained in:
George Hotz
2025-12-14 00:51:28 -05:00
committed by GitHub
parent bcbf832399
commit d7fb5d9b62
2 changed files with 46 additions and 7 deletions

View File

@@ -172,3 +172,42 @@ var, val = bind_uop.unbind()
# Shapes can be symbolic (contain UOps)
shape = tensor.shape # tuple[sint, ...] where sint = int | UOp
```
## Performance Optimization
When optimizing tinygrad internals:
1. **Measure wall time, not just call counts** - Reducing `graph_rewrite` calls doesn't always improve wall time. The overhead of conditional checks can exceed the cost of the operation being skipped.
2. **Profile each optimization individually** - Run benchmarks with and without each change to measure actual impact. Use `test/external/external_benchmark_schedule.py` for schedule/rewrite timing.
3. **Early exits in hot paths are effective** - Simple checks like `if self.op is Ops.CONST: return self` in `simplify()` can eliminate many unnecessary `graph_rewrite` calls.
4. **`graph_rewrite` is expensive** - Each call has overhead even for small graphs. Avoid calling it when the result is trivially known (e.g., simplifying a CONST returns itself).
5. **Beware iterator overhead** - Checks like `all(x.op is Ops.CONST for x in self.src)` can be slower than just running the operation, especially for small sequences.
6. **Verify cache hit rates before adding/keeping caches** - Measure actual hit rates with real workloads. A cache with 0% hit rate is pure overhead (e.g., `pm_cache` was removed because the algorithm guarantees each UOp is only passed to `pm_rewrite` once).
7. **Use `TRACK_MATCH_STATS=2` to profile pattern matching** - This shows match rates and time per pattern. Look for patterns with 0% match rate that still cost significant time - these are pure overhead for that workload.
8. **Cached properties beat manual traversal** - `backward_slice` uses `@functools.cached_property`. A DFS with early-exit sounds faster but is actually slower because it doesn't benefit from caching. The cache hit benefit often outweighs algorithmic improvements.
9. **Avoid creating intermediate objects in hot paths** - For example, `any(x.op in ops for x in self.backward_slice)` is faster than `any(x.op in ops for x in {self:None, **self.backward_slice})` because it avoids dict creation.
## Pattern Matching Profiling
Use `TRACK_MATCH_STATS=2` to identify expensive patterns:
```bash
TRACK_MATCH_STATS=2 PYTHONPATH="." python3 test/external/external_benchmark_schedule.py
```
Output format: `matches / attempts -- match_time / total_time ms -- location`
Key patterns to watch (from ResNet50 benchmark):
- `split_load_store`: ~146ms, 31% match rate - does real work
- `simplify_valid`: ~75ms, 0% match rate in this workload - checks AND ops for INDEX in backward slice
- `vmin==vmax folding`: ~55ms, 0.33% match rate - checks 52K ops but rarely matches
Patterns with 0% match rate are workload-specific overhead. They may be useful in other workloads, so don't remove them without understanding their purpose.

View File

@@ -158,7 +158,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@property
def backward_slice_with_self(self:UOp) -> dict[UOp, None]: return {self:None, **self.backward_slice}
def op_in_backward_slice_with_self(self, *ops:Ops): return any(x.op in ops for x in self.backward_slice_with_self)
def op_in_backward_slice_with_self(self, *ops:Ops) -> bool:
# Check self first, then iterate backward_slice (avoids creating intermediate dict)
return self.op in ops or any(x.op in ops for x in self.backward_slice)
def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
cache: dict[UOp, None] = {}
@@ -340,6 +342,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# *** uop evaluation ***
def simplify(self, tracked=False):
if self.op in {Ops.CONST, Ops.VCONST}: return self
# late import!
from tinygrad.uop.symbolic import symbolic
with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
@@ -1191,16 +1194,13 @@ class BottomUpGate(Exception): pass
class RewriteContext:
def __init__(self, pm, bpm, ctx=None):
self.pm: PatternMatcher|None = pm
self.pm_cache: dict[UOp, UOp|None] = {}
self.bpm: PatternMatcher|None = bpm
self.bpm_cache: dict[UOp, UOp|None] = {}
self.ctx = ctx
self.replace: dict[UOp, UOp] = {}
def cached_pm_rewrite(self, x:UOp) -> UOp|None:
if (ret:=self.pm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
ret = self.pm_cache[x] = unwrap(self.pm).rewrite(x, self.ctx)
return ret
# no cache needed: pm_rewrite is called at most once per UOp due to the replace dict check in unified_rewrite
def pm_rewrite(self, x:UOp) -> UOp|None: return unwrap(self.pm).rewrite(x, self.ctx)
def cached_bpm_rewrite(self, x:UOp) -> UOp|None:
if (ret:=self.bpm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
@@ -1247,7 +1247,7 @@ class RewriteContext:
# in stage 1, once all srcs are rewritten, rebuild (if changed) or run top-down rewrite
if (new_src:=tuple(tmp)) == new_n.src:
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
if self.pm is None or (new_src_n:=self.pm_rewrite(new_n)) is None:
self.replace[n] = new_n
continue
else: