Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-07 22:23:55 -05:00)
schedule: cache unbinds for consistent cache keys (#13662)
* schedule: cache unbinds for consistent cache keys

  Different bound variable values (e.g. KV cache positions) now produce the same schedule cache key by unbinding BIND(DEFINE_VAR, CONST) before computing the cache key and rebinding after lookup.

* schedule: cache unbinds for consistent cache keys

  When scheduling, BIND(DEFINE_VAR, CONST) nodes are now unbound to tagged DEFINE_VARs before computing the cache key. This ensures that the same computation with different bound values (e.g., different KV cache positions in an LLM) gets the same cache key and reuses the cached schedule.

  The fix:
  - pm_pre_sched_cache: replaces BIND with a tagged DEFINE_VAR
  - pm_post_sched_cache: restores the tagged DEFINE_VAR back to the original BIND
  - pm_remove_rangeify_tags: excludes DEFINE_VAR to preserve tags through rangeify
  - var_vals are extracted from BINDs before cache key computation

* schedule: fix BIND handling and add CLAUDE.md

  - Handle BIND to RANGE in create_schedule (not matched by the CONST pattern)
  - Assert that all BINDs on the same variable have the same value
  - Add CLAUDE.md codebase guide

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
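As a minimal sketch of the unbind/rebind round trip the message describes, using the public `Variable` API (the identity assertion relies on UOps being cached):

```python
from tinygrad import Variable

v = Variable('pos', 1, 100)   # a DEFINE_VAR UOp
b = v.bind(42)                # BIND(DEFINE_VAR, CONST)
var, val = b.unbind()         # split back into the variable and its value
assert var is v and val == 42
```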
CLAUDE.md (new file, +130 lines)
@@ -0,0 +1,130 @@
# Claude Code Guide for tinygrad

## Architecture Overview

tinygrad compiles tensor operations into optimized kernels. The pipeline (sketched below):

1. **Tensor** (`tensor.py`) - User-facing API, creates UOp graph
2. **UOp** (`uop/ops.py`) - Unified IR for all operations (both tensor and kernel level)
3. **Schedule** (`engine/schedule.py`, `schedule/`) - Converts tensor UOps to kernel UOps
4. **Codegen** (`codegen/`) - Converts kernel UOps to device code
5. **Runtime** (`runtime/`) - Device-specific execution
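A minimal sketch of steps 1-3 of this pipeline (illustrative; the kernel count depends on the device and optimizations):

```python
from tinygrad import Tensor

t = (Tensor.ones(10) + 1).sum()  # steps 1-2: the Tensor API builds a UOp graph
sched = t.schedule()             # step 3: tensor UOps become kernel ScheduleItems
print(len(sched))                # number of kernels codegen/runtime will run
```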

## Key Concepts

### UOp (Universal Operation)
Everything is a UOp - tensors, operations, buffers, kernels. Key properties:
- `op`: The operation type (Ops enum)
- `dtype`: Data type
- `src`: Tuple of source UOps
- `arg`: Operation-specific argument
- `tag`: Optional tag for graph transformations

UOps are **immutable and cached** - creating the same UOp twice returns the same object (ucache).
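A quick sketch of what ucache means in practice (the `Ops.CONST` constructor arguments here are a plausible minimal example, not taken from this diff):

```python
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops

a = UOp(Ops.CONST, dtypes.int, arg=1)
b = UOp(Ops.CONST, dtypes.int, arg=1)
assert a is b                # structurally identical UOps are the same object
assert a.rtag(()) is not a   # a different tag means a different (cached) UOp
```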

### PatternMatcher
Used extensively for graph transformations:

```python
pm = PatternMatcher([
  (UPat(Ops.ADD, src=(UPat.cvar("x"), UPat.cvar("x"))), lambda x: x * 2),
])
result = graph_rewrite(uop, pm)
```

### Schedule Cache
Schedules are cached by graph structure. BIND nodes (variables with bound values) are unbound before the cache key is computed, so different bound values hit the same cache entry.
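A minimal sketch of the reuse this enables, mirroring the test added in this commit (`schedule_cache` lives in `tinygrad/engine/schedule.py`):

```python
from tinygrad import Tensor, Variable
from tinygrad.engine.schedule import schedule_cache

v = Variable('v', 1, 100)
x = Tensor.ones(10).contiguous().realize()

t1 = (x + Tensor(v.bind(5))).sum()
assert t1.item() == 60.0
n = len(schedule_cache)

# a different bound value reuses the cached schedule: same key, same entry count
t2 = (x + Tensor(v.bind(10))).sum()
assert t2.item() == 110.0
assert len(schedule_cache) == n
```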

## Directory Structure

```
tinygrad/
├── tensor.py        # Tensor class, user API
├── device.py        # Buffer, device management
├── dtype.py         # Data types
├── helpers.py       # Utilities, environment vars
├── uop/
│   ├── ops.py       # UOp class, Ops enum, PatternMatcher
│   ├── spec.py      # UOp type verification
│   └── symbolic.py  # Symbolic math simplification
├── engine/
│   ├── schedule.py  # Schedule creation, caching
│   ├── realize.py   # Tensor realization
│   ├── jit.py       # JIT compilation
│   └── memory.py    # Memory planning
├── schedule/
│   ├── rangeify.py  # Convert movements to ranges
│   └── indexing.py  # Index calculations
├── codegen/
│   ├── kernel.py    # Kernel optimization
│   └── uopgraph.py  # UOp graph transformations
├── renderer/        # Code generation (CUDA, Metal, etc.)
└── runtime/         # Device backends
```

## Testing

```bash
# Run a specific test
python -m pytest test/unit/test_schedule_cache.py -xvs

# Run with a timeout
python -m pytest test/test_symbolic_ops.py -x --timeout=60

# Debug with verbose output
DEBUG=2 python -m pytest test/test_schedule.py::test_name -xvs

# Visualize UOp graphs
VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
```
## Common Environment Variables

- `DEBUG=1-4` - Increasing verbosity
- `VIZ=1` - Enable graph visualization
- `SPEC=1` - Enable UOp spec verification
- `NOOPT=1` - Disable optimizations
- `DEVICE=CPU/CUDA/AMD/METAL` - Set default device
## Debugging Tips

1. **Print UOp graphs**: `print(tensor.uop)` or `print(tensor.uop.sink())`
2. **Check the schedule**: `tensor.schedule()` returns a list of ScheduleItems
3. **Trace graph rewrites**: use `VIZ=1` or add prints in PatternMatcher callbacks
4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]` (see the sketch below)
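A combined sketch of tips 1, 2, and 4 on a small freshly built tensor:

```python
from tinygrad import Tensor
from tinygrad.uop.ops import Ops

t = (Tensor.ones(4) + 2).sum()
print(t.uop)                                             # tip 1: inspect the UOp graph
adds = [u for u in t.uop.toposort() if u.op is Ops.ADD]  # tip 4: filter UOps by type
print(len(t.schedule()))                                 # tip 2: kernels to be run
```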
## Style Notes

- 2-space indentation, 150-char line limit
- PatternMatchers should be defined at module level (they are slow to construct)
- Prefer `graph_rewrite` over manual graph traversal
- UOp methods like `.replace()` preserve tags unless explicitly changed
- Use `.rtag(value)` to add tags to UOps (see the sketch below)
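A short illustration of the last two notes (the `Ops.CONST` UOp is a stand-in; `rtag(())` is the same call the schedule cache's `unbind_var` uses):

```python
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops

x = UOp(Ops.CONST, dtypes.int, arg=3)
tagged = x.rtag(())                   # attach a tag
kept = tagged.replace(arg=4)          # .replace() keeps the tag by default
assert kept.tag == tagged.tag
assert tagged.replace(tag=None) is x  # dropping the tag returns the cached original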
## Common Patterns

### Graph Transformation

```python
def my_transform(ctx, x):
  # return a new UOp, or None to leave the node unchanged
  return x.replace(arg=new_arg)

pm = PatternMatcher([
  (UPat(Ops.SOMETHING, name="x"), my_transform),
])
result = graph_rewrite(input_uop, pm, ctx={})
```
### Finding Variables

```python
# Get all variables in a UOp graph
variables = uop.variables()

# Split a BIND into the variable and its bound value
var, val = bind_uop.unbind()
```
### Shape Handling

```python
# Shapes can be symbolic (contain UOps)
shape = tensor.shape  # tuple[sint, ...] where sint = int | UOp
```
@@ -1,8 +1,31 @@
 import unittest
-from tinygrad import Tensor
+from tinygrad import Tensor, Variable
+from tinygrad.engine.schedule import schedule_cache

 class TestScheduleCache(unittest.TestCase):
+  def test_bound_variable_reuses_cache(self):
+    schedule_cache.clear()
+    v = Variable('v', 1, 100)
+    x = Tensor.ones(10).contiguous().realize()
+
+    # first run with v=5
+    t1 = (x + Tensor(v.bind(5))).sum()
+    self.assertEqual(t1.item(), 60.0)
+    cache_size_after_first = len(schedule_cache)
+
+    # second run with v=10 should reuse cache
+    t2 = (x + Tensor(v.bind(10))).sum()
+    self.assertEqual(t2.item(), 110.0)
+    self.assertEqual(len(schedule_cache), cache_size_after_first)
+
+  def test_bound_variable_var_vals(self):
+    v = Variable('pos', 1, 100)
+    x = Tensor.ones(10).contiguous().realize()
+
+    t = x + Tensor(v.bind(42))
+    _, var_vals = t.schedule_with_vars()
+    self.assertEqual(var_vals, {'pos': 42})
+
   def test_simple(self):
     a = Tensor.ones(10).contiguous()
     b = Tensor.ones(10).contiguous()
@@ -20,12 +20,11 @@ class ScheduleItem:
 # **** schedule linearizer

-def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
+def create_schedule(sched_sink:UOp) -> list[ScheduleItem]:
   with cpu_profile(TracingKey("toposort sched_sink")):
     # construct the KERNEL children graph based on assigns
     children: dict[UOp, list[UOp]] = {}
     in_degree: dict[UOp, int] = {}
-    var_vals: dict[str, int] = {}
     for u in sched_sink.toposort():
       if u.op is Ops.RANGE:
         in_degree.setdefault(u, 0)
@@ -47,11 +46,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       elif s.op is Ops.BUFFER:
         pass # a BUFFER is already realized, nothing to do here
       elif s.op is Ops.BIND:
-        # for RANGE this is in fixedvars
-        if s.src[1].op is not Ops.RANGE:
-          var, val = s.unbind()
-          assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
-          var_vals[var.expr] = val
+        pass # BIND to RANGE handled in fixedvars, BIND to CONST extracted earlier in complete_create_schedule_with_vars
       else:
         raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
@@ -108,7 +103,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
     else:
       real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
     sched_ptr += 1
-  return real_schedule, var_vals
+  return real_schedule

 from tinygrad.engine.memory import memory_planner
 from tinygrad.schedule.rangeify import get_rangeify_map
@@ -124,11 +119,15 @@ def replace_input_buffer(ctx:dict[UOp, UOp], b:UOp):
     ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx))))
   return ret

+def unbind_var(ctx:dict[UOp, UOp], b:UOp):
+  # tag the DEFINE_VAR to distinguish it from unbound ones in rebind
+  ctx[b] = ret = b.src[0].rtag(())
+  return ret
+
 pm_pre_sched_cache = PatternMatcher([
   # replace input buffers
   (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
   # remove unique consts
   (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
+  (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), unbind_var),
 ])

 def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
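To see why this makes cache keys value-independent: after the rewrite the CONST is gone from the graph, so two different bindings produce the same `.key`. A hedged sketch, assuming these names stay importable at module level as the hunk suggests:

```python
from tinygrad import Variable
from tinygrad.uop.ops import graph_rewrite
from tinygrad.engine.schedule import pm_pre_sched_cache

v = Variable('v', 1, 100)
k5 = graph_rewrite(v.bind(5).sink(), pm_pre_sched_cache, ctx={}).key
k10 = graph_rewrite(v.bind(10).sink(), pm_pre_sched_cache, ctx={}).key
assert k5 == k10  # the bound value no longer affects the cache key
```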
@@ -141,6 +140,7 @@ def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
 pm_post_sched_cache = PatternMatcher([
   (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
   (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
+  (UPat(Ops.DEFINE_VAR, name="b"), lambda ctx,b: ctx.get(b) if b.tag is not None else None),
 ])

 schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
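And the rebind direction, sketched under the assumption that the post-pass context maps each tagged DEFINE_VAR back to its BIND (the inversion of the pre-pass map; how the real caller builds it is not shown in this hunk):

```python
from tinygrad import Variable
from tinygrad.uop.ops import graph_rewrite
from tinygrad.engine.schedule import pm_pre_sched_cache, pm_post_sched_cache

v = Variable('v', 1, 100)
sink = v.bind(7).sink()
fwd: dict = {}
unbound = graph_rewrite(sink, pm_pre_sched_cache, ctx=fwd)
back = {tagged: bind for bind, tagged in fwd.items()}  # assumed inversion
rebound = graph_rewrite(unbound, pm_post_sched_cache, ctx=back)
assert rebound is sink  # ucache: the rebuilt graph is the original object
```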
@@ -149,9 +149,16 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
   # big_sink srcs are all the Tensors
   st = time.perf_counter()

-  # replace all UNIQUE buffers with LUNIQUE
+  # replace all UNIQUE buffers with LUNIQUE, unbind BINDs
   input_buffers: dict[UOp, UOp] = {}
   big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
+  # extract var_vals from BINDs that were unbound (ctx stores BIND -> tagged_DEFINE_VAR)
+  var_vals: dict[str, int] = {}
+  for k in input_buffers:
+    if k.op is Ops.BIND:
+      name, val = k.src[0].arg[0], k.src[1].arg
+      assert name not in var_vals or var_vals[name] == val, f"bind mismatch on {k.src[0]}, {var_vals[name]} != {val}"
+      var_vals[name] = val
   sched_cache_key = big_sink_cache.key

   if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
@@ -187,7 +194,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
   tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}

   # create the schedule
-  schedule, var_vals = create_schedule_with_vars(big_sink)
+  schedule = create_schedule(big_sink)
   with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)

   # remove all AFTERs, after scheduling, the tensors are just buffers
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
-from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
+from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, range_str
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
 from tinygrad.helpers import PCONTIG, partition, get_single_element
@@ -522,6 +522,11 @@ add_tags = PatternMatcher([
   (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
 ])

+# NOTE: don't remove tag from DEFINE_VAR - schedule cache uses tags to track unbound variables, and ucache would return original untagged UOp
+pm_remove_rangeify_tags = PatternMatcher([
+  (UPat(GroupOp.All-{Ops.DEFINE_VAR}, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None),
+])
+
 # support for using a contiguous permuted view instead of the parent view if one exists
 # modified from kernelize.py to not use ShapeTracker
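A hedged illustration of the ucache subtlety the NOTE warns about (the DEFINE_VAR arg layout `(name, min, max)` matches how this commit reads `k.src[0].arg[0]`):

```python
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops

v = UOp(Ops.DEFINE_VAR, dtypes.int, arg=('v', 1, 100))
tagged = v.rtag(())
# stripping the tag returns the original cached object, so any
# bookkeeping keyed on the tagged UOp would no longer find this node
assert tagged.replace(tag=None) is v
```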
@@ -580,7 +585,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   # TODO: we can probably get this earlier
   sink_tags = [s.tag for s in tsink.src]
-  tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
+  tsink = graph_rewrite(tsink, pm_remove_rangeify_tags, name="remove rangeify tags")

   if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
@@ -251,7 +251,7 @@ class Tensor(OpMixin):
   def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
     """Creates the schedule needed to realize these Tensor(s)."""
     schedule, var_vals = self.schedule_with_vars(*lst)
-    assert len(var_vals) == 0
+    assert len(schedule) == 0 or len(var_vals) == 0
     return schedule

   @disable_gc()
@@ -1312,7 +1312,6 @@ pm_lower_index_dtype = PatternMatcher([
 def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]

 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
-_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])

 def do_unbind(ctx:dict[Variable, int], x:UOp):
   v,i = x.unbind()