mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Revert "schedule: cache unbinds for consistent cache keys (#13662)"
This reverts commit af86cae10c.
This commit is contained in:
130
CLAUDE.md
130
CLAUDE.md
@@ -1,130 +0,0 @@
|
|||||||
# Claude Code Guide for tinygrad
|
|
||||||
|
|
||||||
## Architecture Overview
|
|
||||||
|
|
||||||
tinygrad compiles tensor operations into optimized kernels. The pipeline:
|
|
||||||
|
|
||||||
1. **Tensor** (`tensor.py`) - User-facing API, creates UOp graph
|
|
||||||
2. **UOp** (`uop/ops.py`) - Unified IR for all operations (both tensor and kernel level)
|
|
||||||
3. **Schedule** (`engine/schedule.py`, `schedule/`) - Converts tensor UOps to kernel UOps
|
|
||||||
4. **Codegen** (`codegen/`) - Converts kernel UOps to device code
|
|
||||||
5. **Runtime** (`runtime/`) - Device-specific execution
|
|
||||||
|
|
||||||
## Key Concepts
|
|
||||||
|
|
||||||
### UOp (Universal Operation)
|
|
||||||
Everything is a UOp - tensors, operations, buffers, kernels. Key properties:
|
|
||||||
- `op`: The operation type (Ops enum)
|
|
||||||
- `dtype`: Data type
|
|
||||||
- `src`: Tuple of source UOps
|
|
||||||
- `arg`: Operation-specific argument
|
|
||||||
- `tag`: Optional tag for graph transformations
|
|
||||||
|
|
||||||
UOps are **immutable and cached** - creating the same UOp twice returns the same object (ucache).
|
|
||||||
|
|
||||||
### PatternMatcher
|
|
||||||
Used extensively for graph transformations:
|
|
||||||
```python
|
|
||||||
pm = PatternMatcher([
|
|
||||||
(UPat(Ops.ADD, src=(UPat.cvar("x"), UPat.cvar("x"))), lambda x: x * 2),
|
|
||||||
])
|
|
||||||
result = graph_rewrite(uop, pm)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Schedule Cache
|
|
||||||
Schedules are cached by graph structure. BIND nodes (variables with bound values) are unbound before cache key computation so different values hit the same cache.
|
|
||||||
|
|
||||||
## Directory Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
tinygrad/
|
|
||||||
├── tensor.py # Tensor class, user API
|
|
||||||
├── device.py # Buffer, device management
|
|
||||||
├── dtype.py # Data types
|
|
||||||
├── helpers.py # Utilities, environment vars
|
|
||||||
├── uop/
|
|
||||||
│ ├── ops.py # UOp class, Ops enum, PatternMatcher
|
|
||||||
│ ├── spec.py # UOp type verification
|
|
||||||
│ └── symbolic.py # Symbolic math simplification
|
|
||||||
├── engine/
|
|
||||||
│ ├── schedule.py # Schedule creation, caching
|
|
||||||
│ ├── realize.py # Tensor realization
|
|
||||||
│ ├── jit.py # JIT compilation
|
|
||||||
│ └── memory.py # Memory planning
|
|
||||||
├── schedule/
|
|
||||||
│ ├── rangeify.py # Convert movements to ranges
|
|
||||||
│ └── indexing.py # Index calculations
|
|
||||||
├── codegen/
|
|
||||||
│ ├── kernel.py # Kernel optimization
|
|
||||||
│ └── uopgraph.py # UOp graph transformations
|
|
||||||
├── renderer/ # Code generation (CUDA, Metal, etc.)
|
|
||||||
└── runtime/ # Device backends
|
|
||||||
```
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run specific test
|
|
||||||
python -m pytest test/unit/test_schedule_cache.py -xvs
|
|
||||||
|
|
||||||
# Run with timeout
|
|
||||||
python -m pytest test/test_symbolic_ops.py -x --timeout=60
|
|
||||||
|
|
||||||
# Debug with print
|
|
||||||
DEBUG=2 python -m pytest test/test_schedule.py::test_name -xvs
|
|
||||||
|
|
||||||
# Visualize UOp graphs
|
|
||||||
VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Common Environment Variables
|
|
||||||
|
|
||||||
- `DEBUG=1-4` - Increasing verbosity
|
|
||||||
- `VIZ=1` - Enable graph visualization
|
|
||||||
- `SPEC=1` - Enable UOp spec verification
|
|
||||||
- `NOOPT=1` - Disable optimizations
|
|
||||||
- `DEVICE=CPU/CUDA/AMD/METAL` - Set default device
|
|
||||||
|
|
||||||
## Debugging Tips
|
|
||||||
|
|
||||||
1. **Print UOp graphs**: `print(tensor.uop)` or `print(tensor.uop.sink())`
|
|
||||||
2. **Check schedule**: `tensor.schedule()` returns list of ScheduleItems
|
|
||||||
3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks
|
|
||||||
4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]`
|
|
||||||
|
|
||||||
## Style Notes
|
|
||||||
|
|
||||||
- 2-space indentation, 150 char line limit
|
|
||||||
- PatternMatchers should be defined at module level (slow to construct)
|
|
||||||
- Prefer `graph_rewrite` over manual graph traversal
|
|
||||||
- UOp methods like `.replace()` preserve tags unless explicitly changed
|
|
||||||
- Use `.rtag(value)` to add tags to UOps
|
|
||||||
|
|
||||||
## Common Patterns
|
|
||||||
|
|
||||||
### Graph Transformation
|
|
||||||
```python
|
|
||||||
def my_transform(ctx, x):
|
|
||||||
# Return new UOp or None to skip
|
|
||||||
return x.replace(arg=new_arg)
|
|
||||||
|
|
||||||
pm = PatternMatcher([
|
|
||||||
(UPat(Ops.SOMETHING, name="x"), my_transform),
|
|
||||||
])
|
|
||||||
result = graph_rewrite(input_uop, pm, ctx={})
|
|
||||||
```
|
|
||||||
|
|
||||||
### Finding Variables
|
|
||||||
```python
|
|
||||||
# Get all variables in a UOp graph
|
|
||||||
variables = uop.variables()
|
|
||||||
|
|
||||||
# Get bound variable values
|
|
||||||
var, val = bind_uop.unbind()
|
|
||||||
```
|
|
||||||
|
|
||||||
### Shape Handling
|
|
||||||
```python
|
|
||||||
# Shapes can be symbolic (contain UOps)
|
|
||||||
shape = tensor.shape # tuple[sint, ...] where sint = int | UOp
|
|
||||||
```
|
|
||||||
@@ -1,31 +1,8 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from tinygrad import Tensor, Variable
|
from tinygrad import Tensor
|
||||||
from tinygrad.engine.schedule import schedule_cache
|
from tinygrad.engine.schedule import schedule_cache
|
||||||
|
|
||||||
class TestScheduleCache(unittest.TestCase):
|
class TestScheduleCache(unittest.TestCase):
|
||||||
def test_bound_variable_reuses_cache(self):
|
|
||||||
schedule_cache.clear()
|
|
||||||
v = Variable('v', 1, 100)
|
|
||||||
x = Tensor.ones(10).contiguous().realize()
|
|
||||||
|
|
||||||
# first run with v=5
|
|
||||||
t1 = (x + Tensor(v.bind(5))).sum()
|
|
||||||
self.assertEqual(t1.item(), 60.0)
|
|
||||||
cache_size_after_first = len(schedule_cache)
|
|
||||||
|
|
||||||
# second run with v=10 should reuse cache
|
|
||||||
t2 = (x + Tensor(v.bind(10))).sum()
|
|
||||||
self.assertEqual(t2.item(), 110.0)
|
|
||||||
self.assertEqual(len(schedule_cache), cache_size_after_first)
|
|
||||||
|
|
||||||
def test_bound_variable_var_vals(self):
|
|
||||||
v = Variable('pos', 1, 100)
|
|
||||||
x = Tensor.ones(10).contiguous().realize()
|
|
||||||
|
|
||||||
t = x + Tensor(v.bind(42))
|
|
||||||
_, var_vals = t.schedule_with_vars()
|
|
||||||
self.assertEqual(var_vals, {'pos': 42})
|
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
a = Tensor.ones(10).contiguous()
|
a = Tensor.ones(10).contiguous()
|
||||||
b = Tensor.ones(10).contiguous()
|
b = Tensor.ones(10).contiguous()
|
||||||
|
|||||||
@@ -20,11 +20,12 @@ class ScheduleItem:
|
|||||||
|
|
||||||
# **** schedule linearizer
|
# **** schedule linearizer
|
||||||
|
|
||||||
def create_schedule(sched_sink:UOp) -> list[ScheduleItem]:
|
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
|
||||||
with cpu_profile(TracingKey("toposort sched_sink")):
|
with cpu_profile(TracingKey("toposort sched_sink")):
|
||||||
# construct the KERNEL children graph based on assigns
|
# construct the KERNEL children graph based on assigns
|
||||||
children: dict[UOp, list[UOp]] = {}
|
children: dict[UOp, list[UOp]] = {}
|
||||||
in_degree: dict[UOp, int] = {}
|
in_degree: dict[UOp, int] = {}
|
||||||
|
var_vals: dict[str, int] = {}
|
||||||
for u in sched_sink.toposort():
|
for u in sched_sink.toposort():
|
||||||
if u.op is Ops.RANGE:
|
if u.op is Ops.RANGE:
|
||||||
in_degree.setdefault(u, 0)
|
in_degree.setdefault(u, 0)
|
||||||
@@ -46,7 +47,11 @@ def create_schedule(sched_sink:UOp) -> list[ScheduleItem]:
|
|||||||
elif s.op is Ops.BUFFER:
|
elif s.op is Ops.BUFFER:
|
||||||
pass # a BUFFER is already realized, nothing to do here
|
pass # a BUFFER is already realized, nothing to do here
|
||||||
elif s.op is Ops.BIND:
|
elif s.op is Ops.BIND:
|
||||||
pass # BIND to RANGE handled in fixedvars, BIND to CONST extracted earlier in complete_create_schedule_with_vars
|
# for RANGE this is in fixedvars
|
||||||
|
if s.src[1].op is not Ops.RANGE:
|
||||||
|
var, val = s.unbind()
|
||||||
|
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
|
||||||
|
var_vals[var.expr] = val
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
|
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
|
||||||
|
|
||||||
@@ -103,7 +108,7 @@ def create_schedule(sched_sink:UOp) -> list[ScheduleItem]:
|
|||||||
else:
|
else:
|
||||||
real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
|
real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
|
||||||
sched_ptr += 1
|
sched_ptr += 1
|
||||||
return real_schedule
|
return real_schedule, var_vals
|
||||||
|
|
||||||
from tinygrad.engine.memory import memory_planner
|
from tinygrad.engine.memory import memory_planner
|
||||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||||
@@ -119,15 +124,11 @@ def replace_input_buffer(ctx:dict[UOp, UOp], b:UOp):
|
|||||||
ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx))))
|
ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx))))
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def unbind_var(ctx:dict[UOp, UOp], b:UOp):
|
|
||||||
# tag the DEFINE_VAR to distinguish it from unbound ones in rebind
|
|
||||||
ctx[b] = ret = b.src[0].rtag(())
|
|
||||||
return ret
|
|
||||||
|
|
||||||
pm_pre_sched_cache = PatternMatcher([
|
pm_pre_sched_cache = PatternMatcher([
|
||||||
|
# replace input buffers
|
||||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
||||||
|
# remove unique consts
|
||||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
|
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
|
||||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), unbind_var),
|
|
||||||
])
|
])
|
||||||
|
|
||||||
def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
|
def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
|
||||||
@@ -140,7 +141,6 @@ def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
|
|||||||
pm_post_sched_cache = PatternMatcher([
|
pm_post_sched_cache = PatternMatcher([
|
||||||
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
|
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
|
||||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
|
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
|
||||||
(UPat(Ops.DEFINE_VAR, name="b"), lambda ctx,b: ctx.get(b) if b.tag is not None else None),
|
|
||||||
])
|
])
|
||||||
|
|
||||||
schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
|
schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
|
||||||
@@ -149,16 +149,9 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||||||
# big_sink srcs are all the Tensors
|
# big_sink srcs are all the Tensors
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
|
|
||||||
# replace all UNIQUE buffers with LUNIQUE, unbind BINDs
|
# replace all UNIQUE buffers with LUNIQUE
|
||||||
input_buffers: dict[UOp, UOp] = {}
|
input_buffers: dict[UOp, UOp] = {}
|
||||||
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
|
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
|
||||||
# extract var_vals from BINDs that were unbound (ctx stores BIND -> tagged_DEFINE_VAR)
|
|
||||||
var_vals: dict[str, int] = {}
|
|
||||||
for k in input_buffers:
|
|
||||||
if k.op is Ops.BIND:
|
|
||||||
name, val = k.src[0].arg[0], k.src[1].arg
|
|
||||||
assert name not in var_vals or var_vals[name] == val, f"bind mismatch on {k.src[0]}, {var_vals[name]} != {val}"
|
|
||||||
var_vals[name] = val
|
|
||||||
sched_cache_key = big_sink_cache.key
|
sched_cache_key = big_sink_cache.key
|
||||||
|
|
||||||
if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
|
if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
|
||||||
@@ -194,7 +187,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||||||
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
|
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
|
||||||
|
|
||||||
# create the schedule
|
# create the schedule
|
||||||
schedule = create_schedule(big_sink)
|
schedule, var_vals = create_schedule_with_vars(big_sink)
|
||||||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
||||||
|
|
||||||
# remove all AFTERs, after scheduling, the tensors are just buffers
|
# remove all AFTERs, after scheduling, the tensors are just buffers
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
|
|||||||
import itertools
|
import itertools
|
||||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, range_str
|
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||||
from tinygrad.uop.symbolic import symbolic
|
from tinygrad.uop.symbolic import symbolic
|
||||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
||||||
@@ -522,11 +522,6 @@ add_tags = PatternMatcher([
|
|||||||
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
|
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
|
||||||
])
|
])
|
||||||
|
|
||||||
# NOTE: don't remove tag from DEFINE_VAR - schedule cache uses tags to track unbound variables, and ucache would return original untagged UOp
|
|
||||||
pm_remove_rangeify_tags = PatternMatcher([
|
|
||||||
(UPat(GroupOp.All-{Ops.DEFINE_VAR}, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None),
|
|
||||||
])
|
|
||||||
|
|
||||||
# support for using a contiguous permuted view instead of the parent view if one exists
|
# support for using a contiguous permuted view instead of the parent view if one exists
|
||||||
# modified from kernelize.py to not use ShapeTracker
|
# modified from kernelize.py to not use ShapeTracker
|
||||||
|
|
||||||
@@ -585,7 +580,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
|||||||
|
|
||||||
# TODO: we can probably get this earlier
|
# TODO: we can probably get this earlier
|
||||||
sink_tags = [s.tag for s in tsink.src]
|
sink_tags = [s.tag for s in tsink.src]
|
||||||
tsink = graph_rewrite(tsink, pm_remove_rangeify_tags, name="remove rangeify tags")
|
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
|
||||||
|
|
||||||
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
||||||
|
|
||||||
|
|||||||
@@ -251,7 +251,7 @@ class Tensor(OpMixin):
|
|||||||
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
|
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
|
||||||
"""Creates the schedule needed to realize these Tensor(s)."""
|
"""Creates the schedule needed to realize these Tensor(s)."""
|
||||||
schedule, var_vals = self.schedule_with_vars(*lst)
|
schedule, var_vals = self.schedule_with_vars(*lst)
|
||||||
assert len(schedule) == 0 or len(var_vals) == 0
|
assert len(var_vals) == 0
|
||||||
return schedule
|
return schedule
|
||||||
|
|
||||||
@disable_gc()
|
@disable_gc()
|
||||||
|
|||||||
@@ -1312,6 +1312,7 @@ pm_lower_index_dtype = PatternMatcher([
|
|||||||
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
||||||
|
|
||||||
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||||
|
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||||
|
|
||||||
def do_unbind(ctx:dict[Variable, int], x:UOp):
|
def do_unbind(ctx:dict[Variable, int], x:UOp):
|
||||||
v,i = x.unbind()
|
v,i = x.unbind()
|
||||||
|
|||||||
Reference in New Issue
Block a user