pre extract afters + sched cleanups (#13720)

* pre extract afters + sched cleanups

* claude.md lesson

* tests for schedule cache

* Revert "tests for schedule cache"

This reverts commit fb3f2e800a.
George Hotz
2025-12-16 16:14:30 -04:00
committed by GitHub
parent 4b741e893f
commit ee45669d14
2 changed files with 17 additions and 7 deletions


@@ -134,6 +134,18 @@ The schedule cache strips values from BIND nodes so different bound values (e.g.
- Only extract var_vals when schedule is non-empty (no kernels = no vars needed)
- PatternMatchers are slow to construct - define at module level, not in functions
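A minimal sketch of the module-level pattern, assuming recent tinygrad import paths and UPat helpers (the matcher, rule, and `simplify` below are illustrative, not taken from this diff):

```python
# Hypothetical example: build the PatternMatcher once at import time, reuse it everywhere.
# Import path and helper names are assumptions and may differ between tinygrad versions.
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, graph_rewrite

# module level: constructed exactly once
pm_fold_add_zero = PatternMatcher([
  # x + c -> x when c is 0; returning None leaves the node unchanged
  (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c"))), lambda x,c: x if c.arg == 0 else None),
])

def simplify(sink: UOp) -> UOp:
  # reuse the prebuilt matcher; constructing a PatternMatcher inside this function would
  # pay the (slow) construction cost on every call
  return graph_rewrite(sink, pm_fold_add_zero)
```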
### Readability Over Speed
Don't add complexity for marginal performance gains. Simpler code that's slightly slower is often better:
```python
# BAD: "optimized" with extra complexity
if has_afters: # skip toposort if no AFTERs
  after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
# GOOD: simple, always works
after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
```
The conditional check adds complexity, potential bugs, and often negligible speedup. Only optimize when profiling shows a real bottleneck.
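For the "profiling shows a real bottleneck" part, a rough way to check whether a micro-optimization is worth keeping (generic timing sketch; `bench` and the function names are placeholders, not part of this change):

```python
import time

def bench(fn, *args, n=100):
  # average wall-clock seconds per call
  st = time.perf_counter()
  for _ in range(n): fn(*args)
  return (time.perf_counter() - st) / n

# keep the simple version unless the measured gap is real, e.g.
# print(f"simple {bench(simple_fn, data)*1e3:.3f} ms vs optimized {bench(fancy_fn, data)*1e3:.3f} ms")
```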
### Testing LLM Changes
```bash
# Quick smoke test


@@ -162,9 +162,10 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
pre_schedule, buf_uops_sink = create_schedule(big_sink)
-# save in schedule cache
-tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]))
-combined_sink = UOp.sink(big_sink, tensor_map_sink, buf_uops_sink)
+# save in schedule cache (include AFTERs in tensor_map so we don't need big_sink)
+after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
+tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]), *flatten(after_map))
+combined_sink = UOp.sink(tensor_map_sink, buf_uops_sink)
schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
else:
# schedule cache hit
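This hunk is the core of the change: the AFTER→buffer pairs are extracted once, when the cache entry is stored, rather than on every schedule call. A stripped-down toy model of that store-time precomputation (plain Python stand-ins, not tinygrad's UOp or schedule cache):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:            # stand-in for a UOp
  op: str
  srcs: tuple = ()

def toposort(root: Node) -> list[Node]:
  seen, order = set(), []
  def visit(n: Node):
    if id(n) in seen: return
    seen.add(id(n))
    for s in n.srcs: visit(s)
    order.append(n)
  visit(root)
  return order

schedule_cache: dict = {}

def get_schedule(key, big_sink: Node, make_schedule):
  if key not in schedule_cache:
    # cache miss: build the schedule and pre-extract AFTER -> buffer pairs once
    # (srcs[0] stands in for u.buf_uop)
    after_map = tuple((n, n.srcs[0]) for n in toposort(big_sink) if n.op == "AFTER")
    schedule_cache[key] = (make_schedule(big_sink), after_map)
  # cache hit: reuse the stored pairs, no walk over big_sink needed
  schedule, after_map = schedule_cache[key]
  return schedule, dict(after_map)
```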
@@ -174,7 +175,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
# replace all the LUNIQUEs with UNIQUEs (single graph_rewrite for everything)
input_buffers_reverse = {v:k for k,v in input_buffers.items()}
combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite combined")
-big_sink, tensor_map_sink, buf_uops_sink = combined.src[0], combined.src[1], combined.src[2]
+tensor_map_sink, buf_uops_sink = combined.src
tm_src = tensor_map_sink.src
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
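The tensor_map round trip works by packing (key, value) pairs flat into a single sink's src and rebuilding the dict with a stride-2 scan, as in the two lines above. A tiny plain-Python illustration of that packing (strings stand in for UOps):

```python
from itertools import chain

tensor_map = {"t0": "buf0", "t1": "buf1"}
# pack: flatten the pairs, like UOp.sink(*flatten(tensor_map.items()))
tm_src = tuple(chain.from_iterable(tensor_map.items()))   # ('t0', 'buf0', 't1', 'buf1')
# unpack: stride-2 scan recovers the mapping after the cache round trip
rebuilt = {tm_src[i]: tm_src[i+1] for i in range(0, len(tm_src), 2)}
assert rebuilt == tensor_map
```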
@@ -207,9 +208,6 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
var_vals[var.expr] = val
-# remove all AFTERs, after scheduling, the tensors are just buffers
-tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\