Revert "schedule: cache unbinds for consistent cache keys (#13662)"

This reverts commit af86cae10c.

Author: George Hotz
Date: 2025-12-12 16:49:50 -05:00
Commit: 8c87a0bf8d (parent 443b7fea80)

6 changed files with 17 additions and 181 deletions

CLAUDE.md (130 lines deleted)

@@ -1,130 +0,0 @@
# Claude Code Guide for tinygrad
## Architecture Overview
tinygrad compiles tensor operations into optimized kernels. The pipeline:
1. **Tensor** (`tensor.py`) - User-facing API, creates UOp graph
2. **UOp** (`uop/ops.py`) - Unified IR for all operations (both tensor and kernel level)
3. **Schedule** (`engine/schedule.py`, `schedule/`) - Converts tensor UOps to kernel UOps
4. **Codegen** (`codegen/`) - Converts kernel UOps to device code
5. **Runtime** (`runtime/`) - Device-specific execution
## Key Concepts
### UOp (Universal Operation)
Everything is a UOp - tensors, operations, buffers, kernels. Key properties:
- `op`: The operation type (Ops enum)
- `dtype`: Data type
- `src`: Tuple of source UOps
- `arg`: Operation-specific argument
- `tag`: Optional tag for graph transformations
UOps are **immutable and cached** - creating the same UOp twice returns the same object (ucache).
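A minimal sketch of the ucache behavior (constructing the same constant twice returns the identical object):
```python
from tinygrad import dtypes
from tinygrad.uop.ops import UOp

# same op/dtype/src/arg -> the ucache returns the same object
a = UOp.const(dtypes.int, 1)
b = UOp.const(dtypes.int, 1)
assert a is b
```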
### PatternMatcher
Used extensively for graph transformations:
```python
pm = PatternMatcher([
(UPat(Ops.ADD, src=(UPat.cvar("x"), UPat.cvar("x"))), lambda x: x * 2),
])
result = graph_rewrite(uop, pm)
```
### Schedule Cache
Schedules are cached by graph structure. BIND nodes (variables with bound values) are unbound before the cache key is computed, so runs that differ only in the bound value hit the same cache entry.
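A minimal sketch of that behavior, mirroring `test/unit/test_schedule_cache.py`:
```python
from tinygrad import Tensor, Variable
from tinygrad.engine.schedule import schedule_cache

v = Variable('v', 1, 100)
x = Tensor.ones(10).contiguous().realize()
(x + Tensor(v.bind(5))).sum().item()   # first run fills the cache
n = len(schedule_cache)
(x + Tensor(v.bind(10))).sum().item()  # same graph structure, new value: cache hit
assert len(schedule_cache) == n
```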
## Directory Structure
```
tinygrad/
├── tensor.py # Tensor class, user API
├── device.py # Buffer, device management
├── dtype.py # Data types
├── helpers.py # Utilities, environment vars
├── uop/
│ ├── ops.py # UOp class, Ops enum, PatternMatcher
│ ├── spec.py # UOp type verification
│ └── symbolic.py # Symbolic math simplification
├── engine/
│ ├── schedule.py # Schedule creation, caching
│ ├── realize.py # Tensor realization
│ ├── jit.py # JIT compilation
│ └── memory.py # Memory planning
├── schedule/
│ ├── rangeify.py # Convert movements to ranges
│ └── indexing.py # Index calculations
├── codegen/
│ ├── kernel.py # Kernel optimization
│ └── uopgraph.py # UOp graph transformations
├── renderer/ # Code generation (CUDA, Metal, etc.)
└── runtime/ # Device backends
```
## Testing
```bash
# Run specific test
python -m pytest test/unit/test_schedule_cache.py -xvs
# Run with timeout
python -m pytest test/test_symbolic_ops.py -x --timeout=60
# Debug with print
DEBUG=2 python -m pytest test/test_schedule.py::test_name -xvs
# Visualize UOp graphs
VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
```
## Common Environment Variables
- `DEBUG=1-4` - Increasing verbosity
- `VIZ=1` - Enable graph visualization
- `SPEC=1` - Enable UOp spec verification
- `NOOPT=1` - Disable optimizations
- `DEVICE=CPU/CUDA/AMD/METAL` - Set default device
## Debugging Tips
1. **Print UOp graphs**: `print(tensor.uop)` or `print(tensor.uop.sink())`
2. **Check schedule**: `tensor.schedule()` returns list of ScheduleItems
3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks
4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]` (combined sketch below)
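A combined sketch of tips 1, 2, and 4 (`Ops.CONST` is just an example op to search for):
```python
from tinygrad import Tensor
from tinygrad.uop.ops import Ops

t = Tensor.ones(4).sum()
print(t.uop)                                                 # tip 1: inspect the UOp graph
consts = [u for u in t.uop.toposort() if u.op is Ops.CONST]  # tip 4: find UOps by type
sched = t.schedule()                                         # tip 2: list of ScheduleItems
```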
## Style Notes
- 2-space indentation, 150 char line limit
- PatternMatchers should be defined at module level (slow to construct)
- Prefer `graph_rewrite` over manual graph traversal
- UOp methods like `.replace()` preserve tags unless explicitly changed
- Use `.rtag(value)` to add tags to UOps
## Common Patterns
### Graph Transformation
```python
def my_transform(ctx, x):
# Return new UOp or None to skip
return x.replace(arg=new_arg)
pm = PatternMatcher([
(UPat(Ops.SOMETHING, name="x"), my_transform),
])
result = graph_rewrite(input_uop, pm, ctx={})
```
### Finding Variables
```python
# Get all variables in a UOp graph
variables = uop.variables()
# Get bound variable values
var, val = bind_uop.unbind()
```
### Shape Handling
```python
# Shapes can be symbolic (contain UOps)
shape = tensor.shape # tuple[sint, ...] where sint = int | UOp
```

test/unit/test_schedule_cache.py

@@ -1,31 +1,8 @@
 import unittest
-from tinygrad import Tensor, Variable
+from tinygrad import Tensor
 from tinygrad.engine.schedule import schedule_cache
 class TestScheduleCache(unittest.TestCase):
-  def test_bound_variable_reuses_cache(self):
-    schedule_cache.clear()
-    v = Variable('v', 1, 100)
-    x = Tensor.ones(10).contiguous().realize()
-    # first run with v=5
-    t1 = (x + Tensor(v.bind(5))).sum()
-    self.assertEqual(t1.item(), 60.0)
-    cache_size_after_first = len(schedule_cache)
-    # second run with v=10 should reuse cache
-    t2 = (x + Tensor(v.bind(10))).sum()
-    self.assertEqual(t2.item(), 110.0)
-    self.assertEqual(len(schedule_cache), cache_size_after_first)
-  def test_bound_variable_var_vals(self):
-    v = Variable('pos', 1, 100)
-    x = Tensor.ones(10).contiguous().realize()
-    t = x + Tensor(v.bind(42))
-    _, var_vals = t.schedule_with_vars()
-    self.assertEqual(var_vals, {'pos': 42})
   def test_simple(self):
     a = Tensor.ones(10).contiguous()
     b = Tensor.ones(10).contiguous()

tinygrad/engine/schedule.py

@@ -20,11 +20,12 @@ class ScheduleItem:
 # **** schedule linearizer
-def create_schedule(sched_sink:UOp) -> list[ScheduleItem]:
+def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
   with cpu_profile(TracingKey("toposort sched_sink")):
     # construct the KERNEL children graph based on assigns
     children: dict[UOp, list[UOp]] = {}
     in_degree: dict[UOp, int] = {}
+    var_vals: dict[str, int] = {}
     for u in sched_sink.toposort():
       if u.op is Ops.RANGE:
         in_degree.setdefault(u, 0)
@@ -46,7 +47,11 @@ def create_schedule(sched_sink:UOp) -> list[ScheduleItem]:
         elif s.op is Ops.BUFFER:
           pass  # a BUFFER is already realized, nothing to do here
         elif s.op is Ops.BIND:
-          pass  # BIND to RANGE handled in fixedvars, BIND to CONST extracted earlier in complete_create_schedule_with_vars
+          # for RANGE this is in fixedvars
+          if s.src[1].op is not Ops.RANGE:
+            var, val = s.unbind()
+            assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
+            var_vals[var.expr] = val
         else:
           raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
@@ -103,7 +108,7 @@ def create_schedule(sched_sink:UOp) -> list[ScheduleItem]:
     else:
       real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
     sched_ptr += 1
-  return real_schedule
+  return real_schedule, var_vals
 from tinygrad.engine.memory import memory_planner
 from tinygrad.schedule.rangeify import get_rangeify_map
@@ -119,15 +124,11 @@ def replace_input_buffer(ctx:dict[UOp, UOp], b:UOp):
   ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx))))
   return ret
-def unbind_var(ctx:dict[UOp, UOp], b:UOp):
-  # tag the DEFINE_VAR to distinguish it from unbound ones in rebind
-  ctx[b] = ret = b.src[0].rtag(())
-  return ret
 pm_pre_sched_cache = PatternMatcher([
+  # replace input buffers
   (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
+  # remove unique consts
   (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
-  (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), unbind_var),
 ])
 def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
@@ -140,7 +141,6 @@ def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
 pm_post_sched_cache = PatternMatcher([
   (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
   (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
-  (UPat(Ops.DEFINE_VAR, name="b"), lambda ctx,b: ctx.get(b) if b.tag is not None else None),
 ])
 schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
@@ -149,16 +149,9 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
   # big_sink srcs are all the Tensors
   st = time.perf_counter()
-  # replace all UNIQUE buffers with LUNIQUE, unbind BINDs
+  # replace all UNIQUE buffers with LUNIQUE
   input_buffers: dict[UOp, UOp] = {}
   big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
-  # extract var_vals from BINDs that were unbound (ctx stores BIND -> tagged DEFINE_VAR)
-  var_vals: dict[str, int] = {}
-  for k in input_buffers:
-    if k.op is Ops.BIND:
-      name, val = k.src[0].arg[0], k.src[1].arg
-      assert name not in var_vals or var_vals[name] == val, f"bind mismatch on {k.src[0]}, {var_vals[name]} != {val}"
-      var_vals[name] = val
   sched_cache_key = big_sink_cache.key
   if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
@@ -194,7 +187,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
   tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
   # create the schedule
-  schedule = create_schedule(big_sink)
+  schedule, var_vals = create_schedule_with_vars(big_sink)
   with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
   # remove all AFTERs, after scheduling, the tensors are just buffers

tinygrad/schedule/rangeify.py

@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
-from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, range_str
+from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
 from tinygrad.helpers import PCONTIG, partition, get_single_element
@@ -522,11 +522,6 @@ add_tags = PatternMatcher([
   (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
 ])
-# NOTE: don't remove tag from DEFINE_VAR - schedule cache uses tags to track unbound variables, and ucache would return original untagged UOp
-pm_remove_rangeify_tags = PatternMatcher([
-  (UPat(GroupOp.All-{Ops.DEFINE_VAR}, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None),
-])
 # support for using a contiguous permuted view instead of the parent view if one exists
 # modified from kernelize.py to not use ShapeTracker
@@ -585,7 +580,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   # TODO: we can probably get this earlier
   sink_tags = [s.tag for s in tsink.src]
-  tsink = graph_rewrite(tsink, pm_remove_rangeify_tags, name="remove rangeify tags")
+  tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
   if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")

tinygrad/tensor.py

@@ -251,7 +251,7 @@ class Tensor(OpMixin):
   def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
     """Creates the schedule needed to realize these Tensor(s)."""
     schedule, var_vals = self.schedule_with_vars(*lst)
-    assert len(schedule) == 0 or len(var_vals) == 0
+    assert len(var_vals) == 0
     return schedule
   @disable_gc()

tinygrad/uop/ops.py

@@ -1312,6 +1312,7 @@ pm_lower_index_dtype = PatternMatcher([
 def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
+_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
 def do_unbind(ctx:dict[Variable, int], x:UOp):
   v,i = x.unbind()