mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
pre extract afters + sched cleanups (#13720)
* pre extract afters + sched cleanups
* claude.md lesson
* tests for schedule cache
* Revert "tests for schedule cache"
This reverts commit fb3f2e800a.
This commit is contained in:
12
CLAUDE.md
12
CLAUDE.md
@@ -134,6 +134,18 @@ The schedule cache strips values from BIND nodes so different bound values (e.g.
|
||||
- Only extract var_vals when schedule is non-empty (no kernels = no vars needed)
|
||||
- PatternMatchers are slow to construct - define at module level, not in functions
|
||||
|
||||
### Readability Over Speed
|
||||
Don't add complexity for marginal performance gains. Simpler code that's slightly slower is often better:
|
||||
```python
|
||||
# BAD: "optimized" with extra complexity
|
||||
if has_afters: # skip toposort if no AFTERs
|
||||
  after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
|
||||
|
||||
# GOOD: simple, always works
|
||||
after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
|
||||
```
|
||||
The conditional check adds complexity and potential bugs, usually for a negligible speedup. Only optimize when profiling shows a real bottleneck.
|
||||
|
||||
### Testing LLM Changes
|
||||
```bash
|
||||
# Quick smoke test
|
||||
|
||||
@@ -162,9 +162,10 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
|
||||
pre_schedule, buf_uops_sink = create_schedule(big_sink)
|
||||
|
||||
# save in schedule cache
|
||||
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]))
|
||||
combined_sink = UOp.sink(big_sink, tensor_map_sink, buf_uops_sink)
|
||||
# save in schedule cache (include AFTERs in tensor_map so we don't need big_sink)
|
||||
after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
|
||||
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]), *flatten(after_map))
|
||||
combined_sink = UOp.sink(tensor_map_sink, buf_uops_sink)
|
||||
schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
|
||||
else:
|
||||
# schedule cache hit
|
||||
@@ -174,7 +175,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
# replace all the LUNIQUEs with UNIQUEs (single graph_rewrite for everything)
|
||||
input_buffers_reverse = {v:k for k,v in input_buffers.items()}
|
||||
combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite combined")
|
||||
big_sink, tensor_map_sink, buf_uops_sink = combined.src[0], combined.src[1], combined.src[2]
|
||||
tensor_map_sink, buf_uops_sink = combined.src
|
||||
tm_src = tensor_map_sink.src
|
||||
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
|
||||
|
||||
@@ -207,9 +208,6 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
|
||||
var_vals[var.expr] = val
|
||||
|
||||
# remove all AFTERs, after scheduling, the tensors are just buffers
|
||||
tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
|
||||
|
||||
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
|
||||
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
||||
|
||||
Reference in New Issue
Block a user