schedule: cache unbinds for consistent cache keys (#13662)

* schedule: cache unbinds for consistent cache keys

Different bound variable values (e.g. KV cache positions) now produce
the same schedule cache key: BIND(DEFINE_VAR, CONST) nodes are unbound
before the cache key is computed and rebound after lookup.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* schedule: cache unbinds for consistent cache keys

When scheduling, BIND(DEFINE_VAR, CONST) nodes are now unbound to
tagged DEFINE_VARs before computing the cache key. This ensures that
the same computation with different bound values (e.g., different
KV cache positions in an LLM) gets the same cache key and reuses the
cached schedule.

The fix:
- pm_pre_sched_cache: replaces BIND with tagged DEFINE_VAR
- pm_post_sched_cache: restores tagged DEFINE_VAR back to original BIND
- pm_remove_rangeify_tags: excludes DEFINE_VAR to preserve tags through rangeify
- var_vals are extracted from the BINDs before cache key computation (see the sketch below)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* schedule: fix BIND handling and add CLAUDE.md

- Handle BIND to RANGE in create_schedule (not matched by CONST pattern)
- Assert all BINDs on same variable have same value
- Add CLAUDE.md codebase guide

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
George Hotz
2025-12-12 16:40:10 -05:00
committed by GitHub
parent fcaed1e1dd
commit af86cae10c
6 changed files with 181 additions and 17 deletions

CLAUDE.md Normal file

@@ -0,0 +1,130 @@
# Claude Code Guide for tinygrad
## Architecture Overview
tinygrad compiles tensor operations into optimized kernels. The pipeline:
1. **Tensor** (`tensor.py`) - User-facing API, creates UOp graph
2. **UOp** (`uop/ops.py`) - Unified IR for all operations (both tensor and kernel level)
3. **Schedule** (`engine/schedule.py`, `schedule/`) - Converts tensor UOps to kernel UOps
4. **Codegen** (`codegen/`) - Converts kernel UOps to device code
5. **Runtime** (`runtime/`) - Device-specific execution
## Key Concepts
### UOp (Universal Operation)
Everything is a UOp - tensors, operations, buffers, kernels. Key properties:
- `op`: The operation type (Ops enum)
- `dtype`: Data type
- `src`: Tuple of source UOps
- `arg`: Operation-specific argument
- `tag`: Optional tag for graph transformations
UOps are **immutable and cached** - creating the same UOp twice returns the same object (ucache).
### PatternMatcher
Used extensively for graph transformations:
```python
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, graph_rewrite

pm = PatternMatcher([
  (UPat(Ops.ADD, src=(UPat.cvar("x"), UPat.cvar("x"))), lambda x: x * 2),
])
result = graph_rewrite(uop, pm)
```
### Schedule Cache
Schedules are cached by graph structure. BIND nodes (variables with bound values) are unbound before cache key computation so different values hit the same cache.
## Directory Structure
```
tinygrad/
├── tensor.py # Tensor class, user API
├── device.py # Buffer, device management
├── dtype.py # Data types
├── helpers.py # Utilities, environment vars
├── uop/
│ ├── ops.py # UOp class, Ops enum, PatternMatcher
│ ├── spec.py # UOp type verification
│ └── symbolic.py # Symbolic math simplification
├── engine/
│ ├── schedule.py # Schedule creation, caching
│ ├── realize.py # Tensor realization
│ ├── jit.py # JIT compilation
│ └── memory.py # Memory planning
├── schedule/
│ ├── rangeify.py # Convert movements to ranges
│ └── indexing.py # Index calculations
├── codegen/
│ ├── kernel.py # Kernel optimization
│ └── uopgraph.py # UOp graph transformations
├── renderer/ # Code generation (CUDA, Metal, etc.)
└── runtime/ # Device backends
```
## Testing
```bash
# Run specific test
python -m pytest test/unit/test_schedule_cache.py -xvs
# Run with timeout
python -m pytest test/test_symbolic_ops.py -x --timeout=60
# Debug with print
DEBUG=2 python -m pytest test/test_schedule.py::test_name -xvs
# Visualize UOp graphs
VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
```
## Common Environment Variables
- `DEBUG=1-4` - Increasing verbosity
- `VIZ=1` - Enable graph visualization
- `SPEC=1` - Enable UOp spec verification
- `NOOPT=1` - Disable optimizations
- `DEVICE=CPU/CUDA/AMD/METAL` - Set default device
## Debugging Tips
1. **Print UOp graphs**: `print(tensor.uop)` or `print(tensor.uop.sink())`
2. **Check schedule**: `tensor.schedule()` returns list of ScheduleItems
3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks
4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]`
## Style Notes
- 2-space indentation, 150 char line limit
- PatternMatchers should be defined at module level (slow to construct)
- Prefer `graph_rewrite` over manual graph traversal
- UOp methods like `.replace()` preserve tags unless explicitly changed
- Use `.rtag(value)` to add tags to UOps
## Common Patterns
### Graph Transformation
```python
def my_transform(ctx, x):
  # Return new UOp or None to skip
  return x.replace(arg=new_arg)

pm = PatternMatcher([
  (UPat(Ops.SOMETHING, name="x"), my_transform),
])
result = graph_rewrite(input_uop, pm, ctx={})
```
### Finding Variables
```python
# Get all variables in a UOp graph
variables = uop.variables()
# Get bound variable values
var, val = bind_uop.unbind()
```
### Shape Handling
```python
# Shapes can be symbolic (contain UOps)
shape = tensor.shape # tuple[sint, ...] where sint = int | UOp
```

test/unit/test_schedule_cache.py

@@ -1,8 +1,31 @@
import unittest
from tinygrad import Tensor
from tinygrad import Tensor, Variable
from tinygrad.engine.schedule import schedule_cache
class TestScheduleCache(unittest.TestCase):
def test_bound_variable_reuses_cache(self):
schedule_cache.clear()
v = Variable('v', 1, 100)
x = Tensor.ones(10).contiguous().realize()
# first run with v=5
t1 = (x + Tensor(v.bind(5))).sum()
self.assertEqual(t1.item(), 60.0)
cache_size_after_first = len(schedule_cache)
# second run with v=10 should reuse cache
t2 = (x + Tensor(v.bind(10))).sum()
self.assertEqual(t2.item(), 110.0)
self.assertEqual(len(schedule_cache), cache_size_after_first)
def test_bound_variable_var_vals(self):
v = Variable('pos', 1, 100)
x = Tensor.ones(10).contiguous().realize()
t = x + Tensor(v.bind(42))
_, var_vals = t.schedule_with_vars()
self.assertEqual(var_vals, {'pos': 42})
def test_simple(self):
a = Tensor.ones(10).contiguous()
b = Tensor.ones(10).contiguous()

tinygrad/engine/schedule.py

@@ -20,12 +20,11 @@ class ScheduleItem:
# **** schedule linearizer
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
def create_schedule(sched_sink:UOp) -> list[ScheduleItem]:
with cpu_profile(TracingKey("toposort sched_sink")):
# construct the KERNEL children graph based on assigns
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
var_vals: dict[str, int] = {}
for u in sched_sink.toposort():
if u.op is Ops.RANGE:
in_degree.setdefault(u, 0)
@@ -47,11 +46,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
elif s.op is Ops.BUFFER:
pass # a BUFFER is already realized, nothing to do here
elif s.op is Ops.BIND:
# for RANGE this is in fixedvars
if s.src[1].op is not Ops.RANGE:
var, val = s.unbind()
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
var_vals[var.expr] = val
pass # BIND to RANGE handled in fixedvars, BIND to CONST extracted earlier in complete_create_schedule_with_vars
else:
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
@@ -108,7 +103,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
else:
real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
sched_ptr += 1
return real_schedule, var_vals
return real_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify_map
@@ -124,11 +119,15 @@ def replace_input_buffer(ctx:dict[UOp, UOp], b:UOp):
ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx))))
return ret
def unbind_var(ctx:dict[UOp, UOp], b:UOp):
# tag the DEFINE_VAR to distinguish it from unbound ones in rebind
ctx[b] = ret = b.src[0].rtag(())
return ret
pm_pre_sched_cache = PatternMatcher([
# replace input buffers
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
# remove unique consts
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), unbind_var),
])
def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
@@ -141,6 +140,7 @@ def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
pm_post_sched_cache = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
(UPat(Ops.DEFINE_VAR, name="b"), lambda ctx,b: ctx.get(b) if b.tag is not None else None),
])
schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
@@ -149,9 +149,16 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
# big_sink srcs are all the Tensors
st = time.perf_counter()
# replace all UNIQUE buffers with LUNIQUE
# replace all UNIQUE buffers with LUNIQUE, unbind BINDs
input_buffers: dict[UOp, UOp] = {}
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
# extract var_vals from BINDs that were unbound (ctx stores BIND -> tagged_DEFINE_VAR)
var_vals: dict[str, int] = {}
for k in input_buffers:
if k.op is Ops.BIND:
name, val = k.src[0].arg[0], k.src[1].arg
assert name not in var_vals or var_vals[name] == val, f"bind mismatch on {k.src[0]}, {var_vals[name]} != {val}"
var_vals[name] = val
sched_cache_key = big_sink_cache.key
if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
@@ -187,7 +194,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
# create the schedule
schedule, var_vals = create_schedule_with_vars(big_sink)
schedule = create_schedule(big_sink)
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
# remove all AFTERs, after scheduling, the tensors are just buffers

tinygrad/schedule/rangeify.py

@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, range_str
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element
@@ -522,6 +522,11 @@ add_tags = PatternMatcher([
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
])
# NOTE: don't remove tag from DEFINE_VAR - schedule cache uses tags to track unbound variables, and ucache would return original untagged UOp
pm_remove_rangeify_tags = PatternMatcher([
(UPat(GroupOp.All-{Ops.DEFINE_VAR}, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None),
])
# support for using a contiguous permuted view instead of the parent view if one exists
# modified from kernelize.py to not use ShapeTracker
@@ -580,7 +585,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# TODO: we can probably get this earlier
sink_tags = [s.tag for s in tsink.src]
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
tsink = graph_rewrite(tsink, pm_remove_rangeify_tags, name="remove rangeify tags")
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")

tinygrad/tensor.py

@@ -251,7 +251,7 @@ class Tensor(OpMixin):
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
"""Creates the schedule needed to realize these Tensor(s)."""
schedule, var_vals = self.schedule_with_vars(*lst)
assert len(var_vals) == 0
assert len(schedule) == 0 or len(var_vals) == 0
return schedule
@disable_gc()

tinygrad/uop/ops.py

@@ -1312,7 +1312,6 @@ pm_lower_index_dtype = PatternMatcher([
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
def do_unbind(ctx:dict[Variable, int], x:UOp):
v,i = x.unbind()