schedule: cache unbinds for consistent cache keys (#13664)

* schedule: cache unbinds for consistent cache keys

strip BIND values before computing cache key so different bound values
(e.g. KV cache positions) hit the same schedule cache entry.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* spec: allow single-src BIND for schedule cache key normalization

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* docs: add lessons learned to CLAUDE.md

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* more claude.md

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
George Hotz
2025-12-12 17:27:42 -05:00
committed by GitHub
parent 27845353a0
commit 55845f7de7
4 changed files with 90 additions and 15 deletions

View File

@@ -92,6 +92,12 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks
4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]`
## Workflow Rules
- **NEVER commit without explicit user approval** - always show the diff and wait for approval
- Run tests before proposing commits
- Test with `SPEC=2` when modifying UOp-related code
## Style Notes
- 2-space indentation, 150 char line limit
@@ -100,6 +106,44 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
- UOp methods like `.replace()` preserve tags unless explicitly changed
- Use `.rtag(value)` to add tags to UOps
## Lessons Learned
### UOp ucache Behavior
UOps are cached by their contents - creating a UOp with identical (op, dtype, src, arg) returns the **same object**. This means:
- `uop.replace(tag=None)` on a tagged UOp returns the original untagged UOp if one already exists in the cache
- Two UOps with the same structure are the same object (so `is` comparison works)
### Spec Validation
When adding new UOp patterns, update `tinygrad/uop/spec.py`. Test with:
```bash
SPEC=2 python3 test/unit/test_something.py
```
Spec issues appear as `RuntimeError: SPEC ISSUE None: UOp(...)`.
### Schedule Cache Key Normalization
The schedule cache strips values from BIND nodes so different bound values (e.g., KV cache positions) hit the same cache entry:
- `pm_pre_sched_cache`: BIND(DEFINE_VAR, CONST) → BIND(DEFINE_VAR) for cache key
- `pm_post_sched_cache`: restores original BIND from context
- When accessing `bind.src[1]`, check `len(bind.src) > 1` first (the value src may have been stripped for the cache key)
- Extract var_vals from `input_buffers` dict after graph_rewrite (avoids extra toposort)
### Avoiding Extra Work
- Use ctx dict from graph_rewrite to collect info during traversal instead of separate toposort
- Only extract var_vals when schedule is non-empty (no kernels = no vars needed)
- PatternMatchers are slow to construct - define at module level, not in functions
### Testing LLM Changes
```bash
# Quick smoke test
echo "Hello" | DEBUG=1 python tinygrad/apps/llm.py --model "llama3.2:1b"
# Check cache hits (should see "cache hit" after warmup)
echo "Hello world" | DEBUG=1 python tinygrad/apps/llm.py --model "llama3.2:1b" 2>&1 | grep cache
# Test with beam search
echo "Hello" | BEAM=2 python tinygrad/apps/llm.py --model "llama3.2:1b"
```
## Common Patterns
### Graph Transformation

View File

@@ -1,8 +1,31 @@
import unittest
from tinygrad import Tensor
from tinygrad import Tensor, Variable
from tinygrad.engine.schedule import schedule_cache
class TestScheduleCache(unittest.TestCase):
def test_bound_variable_reuses_cache(self):
schedule_cache.clear()
v = Variable('v', 1, 100)
x = Tensor.ones(10).contiguous().realize()
# first run with v=5
t1 = (x + Tensor(v.bind(5))).sum()
self.assertEqual(t1.item(), 60.0)
cache_size_after_first = len(schedule_cache)
# second run with v=10 should reuse cache
t2 = (x + Tensor(v.bind(10))).sum()
self.assertEqual(t2.item(), 110.0)
self.assertEqual(len(schedule_cache), cache_size_after_first)
def test_bound_variable_var_vals(self):
v = Variable('pos', 1, 100)
x = Tensor.ones(10).contiguous().realize()
t = x + Tensor(v.bind(42))
_, var_vals = t.schedule_with_vars()
self.assertEqual(var_vals, {'pos': 42})
def test_simple(self):
a = Tensor.ones(10).contiguous()
b = Tensor.ones(10).contiguous()

View File

@@ -20,12 +20,11 @@ class ScheduleItem:
# **** schedule linearizer
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
def create_schedule(sched_sink:UOp) -> list[ScheduleItem]:
with cpu_profile(TracingKey("toposort sched_sink")):
# construct the KERNEL children graph based on assigns
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
var_vals: dict[str, int] = {}
for u in sched_sink.toposort():
if u.op is Ops.RANGE:
in_degree.setdefault(u, 0)
@@ -44,14 +43,8 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
children.setdefault(ss.src[1], []).append(k)
in_degree[k] += 1
elif s.op is Ops.BUFFER:
pass # a BUFFER is already realized, nothing to do here
elif s.op is Ops.BIND:
# for RANGE this is in fixedvars
if s.src[1].op is not Ops.RANGE:
var, val = s.unbind()
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
var_vals[var.expr] = val
elif s.op in {Ops.BUFFER, Ops.BIND}:
pass # a BUFFER is already realized, BINDs are handled in complete_create_schedule_with_vars
else:
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
@@ -73,7 +66,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
assert isinstance(base, Buffer), "base can't be MultiBuffer"
buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and s.src[1].op is Ops.RANGE)
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
if any(isinstance(x, MultiBuffer) for x in ubufs):
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
@@ -108,7 +101,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
else:
real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
sched_ptr += 1
return real_schedule, var_vals
return real_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify_map
@@ -129,6 +122,8 @@ pm_pre_sched_cache = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
# remove unique consts
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
# strip value from BIND for cache key normalization, so different values hit same cache
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), lambda ctx,b: ctx.setdefault(b, b.replace(src=(b.src[0],)))),
])
def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
@@ -141,6 +136,8 @@ def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
pm_post_sched_cache = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
# restore BIND value stripped in pm_pre_sched_cache
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR),), name="b"), lambda ctx,b: ctx.get(b)),
])
schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
@@ -149,7 +146,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
# big_sink srcs are all the Tensors
st = time.perf_counter()
# replace all UNIQUE buffers with LUNIQUE
# replace all UNIQUE buffers with LUNIQUE, strip BIND values for cache key
input_buffers: dict[UOp, UOp] = {}
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
sched_cache_key = big_sink_cache.key
@@ -187,9 +184,18 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
# create the schedule
schedule, var_vals = create_schedule_with_vars(big_sink)
schedule = create_schedule(big_sink)
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
# extract var_vals from BINDs that were stripped (only if there are kernels)
var_vals: dict[str, int] = {}
if schedule:
for u in input_buffers:
if u.op is Ops.BIND:
var, val = u.unbind()
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
var_vals[var.expr] = val
# remove all AFTERs, after scheduling, the tensors are just buffers
tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}

View File

@@ -83,6 +83,8 @@ _tensor_spec = PatternMatcher([
# Tensor variable bindings
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
# single-src BIND used for schedule cache key normalization
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR),), arg=None), lambda: True),
# device or unique
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),