diff --git a/CLAUDE.md b/CLAUDE.md
index eb262afc0f..b6b660c5d7 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -134,6 +134,18 @@ The schedule cache strips values from BIND nodes so different bound values (e.g.
 - Only extract var_vals when schedule is non-empty (no kernels = no vars needed)
 - PatternMatchers are slow to construct - define at module level, not in functions
 
+### Readability Over Speed
+Don't add complexity for marginal performance gains. Simpler code that's slightly slower is often better:
+```python
+# BAD: "optimized" with extra complexity
+if has_afters: # skip toposort if no AFTERs
+  after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
+
+# GOOD: simple, always works
+after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
+```
+The conditional check adds complexity and potential bugs, often for a negligible speedup. Only optimize when profiling shows a real bottleneck.
+
 ### Testing LLM Changes
 ```bash
 # Quick smoke test
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index ae25f79bcb..d2ee084711 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -162,9 +162,10 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
 
     pre_schedule, buf_uops_sink = create_schedule(big_sink)
 
-    # save in schedule cache
-    tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]))
-    combined_sink = UOp.sink(big_sink, tensor_map_sink, buf_uops_sink)
+    # save in schedule cache (include AFTERs in tensor_map so we don't need big_sink)
+    after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
+    tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]), *flatten(after_map))
+    combined_sink = UOp.sink(tensor_map_sink, buf_uops_sink)
     schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
   else:
     # schedule cache hit
@@ -174,7 +175,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
   # replace all the LUNIQUEs with UNIQUEs (single graph_rewrite for everything)
   input_buffers_reverse = {v:k for k,v in input_buffers.items()}
   combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite combined")
-  big_sink, tensor_map_sink, buf_uops_sink = combined.src[0], combined.src[1], combined.src[2]
+  tensor_map_sink, buf_uops_sink = combined.src
 
   tm_src = tensor_map_sink.src
   tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
@@ -207,9 +208,6 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
     assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
     var_vals[var.expr] = val
 
-  # remove all AFTERs, after scheduling, the tensors are just buffers
-  tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
-
   if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
     print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
       f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\