mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
speedups: early return from simplify (#13665)
* early return from simplify
* pm_rewrite
* more speed
* remove again
* early return from simplify
* ugh
CLAUDE.md (+39)
@@ -172,3 +172,42 @@ var, val = bind_uop.unbind()
# Shapes can be symbolic (contain UOps)
shape = tensor.shape # tuple[sint, ...] where sint = int | UOp
```

## Performance Optimization

When optimizing tinygrad internals:

1. **Measure wall time, not just call counts** - Reducing `graph_rewrite` calls doesn't always improve wall time. The overhead of conditional checks can exceed the cost of the operation being skipped.

2. **Profile each optimization individually** - Run benchmarks with and without each change to measure actual impact. Use `test/external/external_benchmark_schedule.py` for schedule/rewrite timing.

3. **Early exits in hot paths are effective** - Simple checks like `if self.op is Ops.CONST: return self` in `simplify()` can eliminate many unnecessary `graph_rewrite` calls.

4. **`graph_rewrite` is expensive** - Each call has overhead even for small graphs. Avoid calling it when the result is trivially known (e.g., simplifying a CONST returns itself).

5. **Beware iterator overhead** - Checks like `all(x.op is Ops.CONST for x in self.src)` can be slower than just running the operation, especially for small sequences.

6. **Verify cache hit rates before adding/keeping caches** - Measure actual hit rates with real workloads. A cache with 0% hit rate is pure overhead (e.g., `pm_cache` was removed because the algorithm guarantees each UOp is only passed to `pm_rewrite` once).

7. **Use `TRACK_MATCH_STATS=2` to profile pattern matching** - This shows match rates and time per pattern. Look for patterns with 0% match rate that still cost significant time - these are pure overhead for that workload.

8. **Cached properties beat manual traversal** - `backward_slice` uses `@functools.cached_property`. A DFS with early-exit sounds faster but is actually slower because it doesn't benefit from caching. The cache hit benefit often outweighs algorithmic improvements.

9. **Avoid creating intermediate objects in hot paths** - For example, `any(x.op in ops for x in self.backward_slice)` is faster than `any(x.op in ops for x in {self:None, **self.backward_slice})` because it avoids dict creation (see the timing sketch after this list).
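To make items 1 and 9 concrete, here is a minimal timing sketch. `FakeNode` is a stand-in invented for this example (it is not tinygrad's `UOp`); it only mimics the `op`/`backward_slice` shape used in the comparison above.

```python
import timeit

# FakeNode is a stand-in for this sketch only; it is not tinygrad's UOp.
class FakeNode:
    def __init__(self, op, backward_slice=None):
        self.op = op
        self.backward_slice = backward_slice or {}  # dict[FakeNode, None], like UOp.backward_slice

node = FakeNode("ADD", {FakeNode("MUL"): None for _ in range(300)})
ops = {"LOAD", "STORE"}

def with_dict_copy():
    # builds a throwaway ~300-entry dict on every call
    return any(x.op in ops for x in {node: None, **node.backward_slice})

def without_dict_copy():
    # checks the node itself first, then walks the existing dict directly
    return node.op in ops or any(x.op in ops for x in node.backward_slice)

# item 1: compare wall time, not call counts
for fn in (with_dict_copy, without_dict_copy):
    print(f"{fn.__name__}: {timeit.timeit(fn, number=10_000):.3f}s")
```

Both versions scan the slice, but only the first also constructs a throwaway dict per call; timing them side by side is the wall-time comparison item 1 asks for instead of reasoning from call counts alone.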

## Pattern Matching Profiling

Use `TRACK_MATCH_STATS=2` to identify expensive patterns:

```bash
TRACK_MATCH_STATS=2 PYTHONPATH="." python3 test/external/external_benchmark_schedule.py
```

Output format: `matches / attempts -- match_time / total_time ms -- location`

Key patterns to watch (from ResNet50 benchmark):
- `split_load_store`: ~146ms, 31% match rate - does real work
- `simplify_valid`: ~75ms, 0% match rate in this workload - checks AND ops for INDEX in backward slice
- `vmin==vmax folding`: ~55ms, 0.33% match rate - checks 52K ops but rarely matches

Patterns with 0% match rate are workload-specific overhead. They may be useful in other workloads, so don't remove them without understanding their purpose.
@@ -158,7 +158,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   @property
   def backward_slice_with_self(self:UOp) -> dict[UOp, None]: return {self:None, **self.backward_slice}
-  def op_in_backward_slice_with_self(self, *ops:Ops): return any(x.op in ops for x in self.backward_slice_with_self)
+  def op_in_backward_slice_with_self(self, *ops:Ops) -> bool:
+    # Check self first, then iterate backward_slice (avoids creating intermediate dict)
+    return self.op in ops or any(x.op in ops for x in self.backward_slice)

   def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
     cache: dict[UOp, None] = {}
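Item 8 in the CLAUDE.md list above refers to `backward_slice` being a `functools.cached_property`. As a self-contained illustration of that pattern, here is a sketch built on a `ToyNode` stand-in (invented for this example, not tinygrad's `UOp`):

```python
import functools

class ToyNode:
    def __init__(self, src=()):
        self.src = tuple(src)

    @functools.cached_property
    def backward_slice(self) -> dict["ToyNode", None]:
        # full transitive closure of sources, computed once and stored on the instance
        out: dict[ToyNode, None] = {}
        for s in self.src:
            out[s] = None
            out.update(s.backward_slice)  # children's cached results are reused too
        return out

leaf = ToyNode()
mid = ToyNode([leaf])
root = ToyNode([mid, leaf])
assert set(root.backward_slice) == {mid, leaf}  # first access does the traversal
_ = root.backward_slice                         # later accesses just return the stored dict
```

The first access pays for the traversal and stores the result on the instance; every later access, including the scans in `op_in_backward_slice_with_self` above, starts from the already-built dict, which is the caching benefit item 8 says tends to beat a hand-rolled early-exit DFS.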
@@ -340,6 +342,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   # *** uop evaluation ***

   def simplify(self, tracked=False):
+    if self.op in {Ops.CONST, Ops.VCONST}: return self
     # late import!
     from tinygrad.uop.symbolic import symbolic
     with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
@@ -1191,16 +1194,13 @@ class BottomUpGate(Exception): pass
 class RewriteContext:
   def __init__(self, pm, bpm, ctx=None):
     self.pm: PatternMatcher|None = pm
-    self.pm_cache: dict[UOp, UOp|None] = {}
     self.bpm: PatternMatcher|None = bpm
     self.bpm_cache: dict[UOp, UOp|None] = {}
     self.ctx = ctx
     self.replace: dict[UOp, UOp] = {}

-  def cached_pm_rewrite(self, x:UOp) -> UOp|None:
-    if (ret:=self.pm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
-    ret = self.pm_cache[x] = unwrap(self.pm).rewrite(x, self.ctx)
-    return ret
+  # no cache needed: pm_rewrite is called at most once per UOp due to the replace dict check in unified_rewrite
+  def pm_rewrite(self, x:UOp) -> UOp|None: return unwrap(self.pm).rewrite(x, self.ctx)

   def cached_bpm_rewrite(self, x:UOp) -> UOp|None:
     if (ret:=self.bpm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
@@ -1247,7 +1247,7 @@ class RewriteContext:
       # in stage 1, once all srcs are rewritten, rebuild (if changed) or run top-down rewrite
       if (new_src:=tuple(tmp)) == new_n.src:
         # if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
-        if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
+        if self.pm is None or (new_src_n:=self.pm_rewrite(new_n)) is None:
           self.replace[n] = new_n
           continue
       else:
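The removed `pm_cache` is the example behind item 6 above. Here is a toy sketch of why that cache could never hit (it is not tinygrad's `unified_rewrite`; `Node` and `rewrite_one` are invented for illustration): every finished node is recorded in a replace dict and looked up before any further work, so the rewrite function sees each distinct node at most once.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    op: str
    src: tuple["Node", ...] = ()

calls = 0
def rewrite_one(n: Node) -> Node | None:
    """Stand-in for pm.rewrite: fold NEG(CONST) and count invocations."""
    global calls
    calls += 1
    return Node("CONST") if n.op == "NEG" and n.src and n.src[0].op == "CONST" else None

def rewrite_graph(root: Node) -> Node:
    replace: dict[Node, Node] = {}
    def run(n: Node) -> Node:
        if n in replace: return replace[n]  # finished nodes never reach rewrite_one again
        new_n = Node(n.op, tuple(run(s) for s in n.src))
        replace[n] = out = rewrite_one(new_n) or new_n
        return out
    return run(root)

shared = Node("NEG", (Node("CONST"),))
root = Node("ADD", (shared, shared))  # the shared subgraph appears twice in the graph
print(rewrite_graph(root), "rewrite_one calls:", calls)  # 3 calls: CONST, NEG, ADD
```

Even with `shared` used twice, `rewrite_one` runs three times, once per distinct node, so a result cache sitting in front of it would have a 0% hit rate - exactly the situation item 6 warns about.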