mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
schedule: cache unbinds for consistent cache keys (#13664)
* schedule: cache unbinds for consistent cache keys

  strip BIND values before computing cache key so different bound values (e.g. KV cache positions) hit the same schedule cache entry.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* spec: allow single-src BIND for schedule cache key normalization

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* docs: add lessons learned to CLAUDE.md

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* more claude.md

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
44
CLAUDE.md
44
CLAUDE.md
@@ -92,6 +92,12 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
|
||||
3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks
|
||||
4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]`
|
||||
|
||||
## Workflow Rules
|
||||
|
||||
- **NEVER commit without explicit user approval** - always show the diff and wait for approval
|
||||
- Run tests before proposing commits
|
||||
- Test with `SPEC=2` when modifying UOp-related code
|
||||
|
||||
## Style Notes
|
||||
|
||||
- 2-space indentation, 150 char line limit
|
||||
@@ -100,6 +106,44 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
|
||||
- UOp methods like `.replace()` preserve tags unless explicitly changed
|
||||
- Use `.rtag(value)` to add tags to UOps
|
||||
|
||||
## Lessons Learned
|
||||
|
||||
### UOp ucache Behavior
|
||||
UOps are cached by their contents - creating a UOp with identical (op, dtype, src, arg, tag) returns the **same object**. This means:
|
||||
- `uop.replace(tag=None)` on a tagged UOp returns the original untagged UOp if it exists in cache
|
||||
- Two UOps with the same structure are the same object (`is` comparison works)
|
||||
|
||||
### Spec Validation
|
||||
When adding new UOp patterns, update `tinygrad/uop/spec.py`. Test with:
|
||||
```bash
|
||||
SPEC=2 python3 test/unit/test_something.py
|
||||
```
|
||||
Spec issues appear as `RuntimeError: SPEC ISSUE None: UOp(...)`.
|
||||
|
||||
### Schedule Cache Key Normalization
|
||||
The schedule cache strips values from BIND nodes so different bound values (e.g., KV cache positions) hit the same cache entry:
|
||||
- `pm_pre_sched_cache`: BIND(DEFINE_VAR, CONST) → BIND(DEFINE_VAR) for cache key
|
||||
- `pm_post_sched_cache`: restores original BIND from context
|
||||
- When accessing `bind.src[1]`, check `len(bind.src) > 1` first (the value may have been stripped by `pm_pre_sched_cache`)
|
||||
- Extract var_vals from `input_buffers` dict after graph_rewrite (avoids extra toposort)
|
||||
|
||||
### Avoiding Extra Work
|
||||
- Use ctx dict from graph_rewrite to collect info during traversal instead of separate toposort
|
||||
- Only extract var_vals when schedule is non-empty (no kernels = no vars needed)
|
||||
- PatternMatchers are slow to construct - define at module level, not in functions
|
||||
|
||||
### Testing LLM Changes
|
||||
```bash
|
||||
# Quick smoke test
|
||||
echo "Hello" | DEBUG=1 python tinygrad/apps/llm.py --model "llama3.2:1b"
|
||||
|
||||
# Check cache hits (should see "cache hit" after warmup)
|
||||
echo "Hello world" | DEBUG=1 python tinygrad/apps/llm.py --model "llama3.2:1b" 2>&1 | grep cache
|
||||
|
||||
# Test with beam search
|
||||
echo "Hello" | BEAM=2 python tinygrad/apps/llm.py --model "llama3.2:1b"
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Graph Transformation
|
||||
|
||||
@@ -1,8 +1,31 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad import Tensor, Variable
|
||||
from tinygrad.engine.schedule import schedule_cache
|
||||
|
||||
class TestScheduleCache(unittest.TestCase):
|
||||
def test_bound_variable_reuses_cache(self):
|
||||
schedule_cache.clear()
|
||||
v = Variable('v', 1, 100)
|
||||
x = Tensor.ones(10).contiguous().realize()
|
||||
|
||||
# first run with v=5
|
||||
t1 = (x + Tensor(v.bind(5))).sum()
|
||||
self.assertEqual(t1.item(), 60.0)
|
||||
cache_size_after_first = len(schedule_cache)
|
||||
|
||||
# second run with v=10 should reuse cache
|
||||
t2 = (x + Tensor(v.bind(10))).sum()
|
||||
self.assertEqual(t2.item(), 110.0)
|
||||
self.assertEqual(len(schedule_cache), cache_size_after_first)
|
||||
|
||||
def test_bound_variable_var_vals(self):
|
||||
v = Variable('pos', 1, 100)
|
||||
x = Tensor.ones(10).contiguous().realize()
|
||||
|
||||
t = x + Tensor(v.bind(42))
|
||||
_, var_vals = t.schedule_with_vars()
|
||||
self.assertEqual(var_vals, {'pos': 42})
|
||||
|
||||
def test_simple(self):
|
||||
a = Tensor.ones(10).contiguous()
|
||||
b = Tensor.ones(10).contiguous()
|
||||
|
||||
@@ -20,12 +20,11 @@ class ScheduleItem:
|
||||
|
||||
# **** schedule linearizer
|
||||
|
||||
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
|
||||
def create_schedule(sched_sink:UOp) -> list[ScheduleItem]:
|
||||
with cpu_profile(TracingKey("toposort sched_sink")):
|
||||
# construct the KERNEL children graph based on assigns
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
in_degree: dict[UOp, int] = {}
|
||||
var_vals: dict[str, int] = {}
|
||||
for u in sched_sink.toposort():
|
||||
if u.op is Ops.RANGE:
|
||||
in_degree.setdefault(u, 0)
|
||||
@@ -44,14 +43,8 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
||||
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
|
||||
children.setdefault(ss.src[1], []).append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op is Ops.BUFFER:
|
||||
pass # a BUFFER is already realized, nothing to do here
|
||||
elif s.op is Ops.BIND:
|
||||
# for RANGE this is in fixedvars
|
||||
if s.src[1].op is not Ops.RANGE:
|
||||
var, val = s.unbind()
|
||||
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
|
||||
var_vals[var.expr] = val
|
||||
elif s.op in {Ops.BUFFER, Ops.BIND}:
|
||||
pass # a BUFFER is already realized, BINDs are handled in complete_create_schedule_with_vars
|
||||
else:
|
||||
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
|
||||
|
||||
@@ -73,7 +66,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
||||
assert isinstance(base, Buffer), "base can't be MultiBuffer"
|
||||
buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and s.src[1].op is Ops.RANGE)
|
||||
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
if any(isinstance(x, MultiBuffer) for x in ubufs):
|
||||
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
||||
dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
|
||||
@@ -108,7 +101,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
||||
else:
|
||||
real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
|
||||
sched_ptr += 1
|
||||
return real_schedule, var_vals
|
||||
return real_schedule
|
||||
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||
@@ -129,6 +122,8 @@ pm_pre_sched_cache = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
||||
# remove unique consts
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
|
||||
# strip value from BIND for cache key normalization, so different values hit same cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), lambda ctx,b: ctx.setdefault(b, b.replace(src=(b.src[0],)))),
|
||||
])
|
||||
|
||||
def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
|
||||
@@ -141,6 +136,8 @@ def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
|
||||
pm_post_sched_cache = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
|
||||
# restore BIND value stripped in pm_pre_sched_cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR),), name="b"), lambda ctx,b: ctx.get(b)),
|
||||
])
|
||||
|
||||
schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
|
||||
@@ -149,7 +146,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
|
||||
# replace all UNIQUE buffers with LUNIQUE
|
||||
# replace all UNIQUE buffers with LUNIQUE, strip BIND values for cache key
|
||||
input_buffers: dict[UOp, UOp] = {}
|
||||
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
|
||||
sched_cache_key = big_sink_cache.key
|
||||
@@ -187,9 +184,18 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
|
||||
|
||||
# create the schedule
|
||||
schedule, var_vals = create_schedule_with_vars(big_sink)
|
||||
schedule = create_schedule(big_sink)
|
||||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
||||
|
||||
# extract var_vals from BINDs that were stripped (only if there are kernels)
|
||||
var_vals: dict[str, int] = {}
|
||||
if schedule:
|
||||
for u in input_buffers:
|
||||
if u.op is Ops.BIND:
|
||||
var, val = u.unbind()
|
||||
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
|
||||
var_vals[var.expr] = val
|
||||
|
||||
# remove all AFTERs, after scheduling, the tensors are just buffers
|
||||
tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
|
||||
|
||||
|
||||
@@ -83,6 +83,8 @@ _tensor_spec = PatternMatcher([
|
||||
|
||||
# Tensor variable bindings
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
|
||||
# single-src BIND used for schedule cache key normalization
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR),), arg=None), lambda: True),
|
||||
|
||||
# device or unique
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user