From 8c87a0bf8de0b0a5ee66734c07b802d531aeb92b Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 12 Dec 2025 16:49:50 -0500 Subject: [PATCH] Revert "schedule: cache unbinds for consistent cache keys (#13662)" This reverts commit af86cae10c2cc656c3ed3e5400f0a49e96619c52. --- CLAUDE.md | 130 ------------------------------- test/unit/test_schedule_cache.py | 25 +----- tinygrad/engine/schedule.py | 31 +++----- tinygrad/schedule/rangeify.py | 9 +-- tinygrad/tensor.py | 2 +- tinygrad/uop/ops.py | 1 + 6 files changed, 17 insertions(+), 181 deletions(-) delete mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 6d792002dc..0000000000 --- a/CLAUDE.md +++ /dev/null @@ -1,130 +0,0 @@ -# Claude Code Guide for tinygrad - -## Architecture Overview - -tinygrad compiles tensor operations into optimized kernels. The pipeline: - -1. **Tensor** (`tensor.py`) - User-facing API, creates UOp graph -2. **UOp** (`uop/ops.py`) - Unified IR for all operations (both tensor and kernel level) -3. **Schedule** (`engine/schedule.py`, `schedule/`) - Converts tensor UOps to kernel UOps -4. **Codegen** (`codegen/`) - Converts kernel UOps to device code -5. **Runtime** (`runtime/`) - Device-specific execution - -## Key Concepts - -### UOp (Universal Operation) -Everything is a UOp - tensors, operations, buffers, kernels. Key properties: -- `op`: The operation type (Ops enum) -- `dtype`: Data type -- `src`: Tuple of source UOps -- `arg`: Operation-specific argument -- `tag`: Optional tag for graph transformations - -UOps are **immutable and cached** - creating the same UOp twice returns the same object (ucache). - -### PatternMatcher -Used extensively for graph transformations: -```python -pm = PatternMatcher([ - (UPat(Ops.ADD, src=(UPat.cvar("x"), UPat.cvar("x"))), lambda x: x * 2), -]) -result = graph_rewrite(uop, pm) -``` - -### Schedule Cache -Schedules are cached by graph structure. BIND nodes (variables with bound values) are unbound before cache key computation so different values hit the same cache. - -## Directory Structure - -``` -tinygrad/ -├── tensor.py # Tensor class, user API -├── device.py # Buffer, device management -├── dtype.py # Data types -├── helpers.py # Utilities, environment vars -├── uop/ -│ ├── ops.py # UOp class, Ops enum, PatternMatcher -│ ├── spec.py # UOp type verification -│ └── symbolic.py # Symbolic math simplification -├── engine/ -│ ├── schedule.py # Schedule creation, caching -│ ├── realize.py # Tensor realization -│ ├── jit.py # JIT compilation -│ └── memory.py # Memory planning -├── schedule/ -│ ├── rangeify.py # Convert movements to ranges -│ └── indexing.py # Index calculations -├── codegen/ -│ ├── kernel.py # Kernel optimization -│ └── uopgraph.py # UOp graph transformations -├── renderer/ # Code generation (CUDA, Metal, etc.) -└── runtime/ # Device backends -``` - -## Testing - -```bash -# Run specific test -python -m pytest test/unit/test_schedule_cache.py -xvs - -# Run with timeout -python -m pytest test/test_symbolic_ops.py -x --timeout=60 - -# Debug with print -DEBUG=2 python -m pytest test/test_schedule.py::test_name -xvs - -# Visualize UOp graphs -VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()" -``` - -## Common Environment Variables - -- `DEBUG=1-4` - Increasing verbosity -- `VIZ=1` - Enable graph visualization -- `SPEC=1` - Enable UOp spec verification -- `NOOPT=1` - Disable optimizations -- `DEVICE=CPU/CUDA/AMD/METAL` - Set default device - -## Debugging Tips - -1. 
**Print UOp graphs**: `print(tensor.uop)` or `print(tensor.uop.sink())` -2. **Check schedule**: `tensor.schedule()` returns list of ScheduleItems -3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks -4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]` - -## Style Notes - -- 2-space indentation, 150 char line limit -- PatternMatchers should be defined at module level (slow to construct) -- Prefer `graph_rewrite` over manual graph traversal -- UOp methods like `.replace()` preserve tags unless explicitly changed -- Use `.rtag(value)` to add tags to UOps - -## Common Patterns - -### Graph Transformation -```python -def my_transform(ctx, x): - # Return new UOp or None to skip - return x.replace(arg=new_arg) - -pm = PatternMatcher([ - (UPat(Ops.SOMETHING, name="x"), my_transform), -]) -result = graph_rewrite(input_uop, pm, ctx={}) -``` - -### Finding Variables -```python -# Get all variables in a UOp graph -variables = uop.variables() - -# Get bound variable values -var, val = bind_uop.unbind() -``` - -### Shape Handling -```python -# Shapes can be symbolic (contain UOps) -shape = tensor.shape # tuple[sint, ...] where sint = int | UOp -``` diff --git a/test/unit/test_schedule_cache.py b/test/unit/test_schedule_cache.py index 368a517d73..8fc12eabce 100644 --- a/test/unit/test_schedule_cache.py +++ b/test/unit/test_schedule_cache.py @@ -1,31 +1,8 @@ import unittest -from tinygrad import Tensor, Variable +from tinygrad import Tensor from tinygrad.engine.schedule import schedule_cache class TestScheduleCache(unittest.TestCase): - def test_bound_variable_reuses_cache(self): - schedule_cache.clear() - v = Variable('v', 1, 100) - x = Tensor.ones(10).contiguous().realize() - - # first run with v=5 - t1 = (x + Tensor(v.bind(5))).sum() - self.assertEqual(t1.item(), 60.0) - cache_size_after_first = len(schedule_cache) - - # second run with v=10 should reuse cache - t2 = (x + Tensor(v.bind(10))).sum() - self.assertEqual(t2.item(), 110.0) - self.assertEqual(len(schedule_cache), cache_size_after_first) - - def test_bound_variable_var_vals(self): - v = Variable('pos', 1, 100) - x = Tensor.ones(10).contiguous().realize() - - t = x + Tensor(v.bind(42)) - _, var_vals = t.schedule_with_vars() - self.assertEqual(var_vals, {'pos': 42}) - def test_simple(self): a = Tensor.ones(10).contiguous() b = Tensor.ones(10).contiguous() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index d31ca67624..f68c0689b6 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -20,11 +20,12 @@ class ScheduleItem: # **** schedule linearizer -def create_schedule(sched_sink:UOp) -> list[ScheduleItem]: +def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]: with cpu_profile(TracingKey("toposort sched_sink")): # construct the KERNEL children graph based on assigns children: dict[UOp, list[UOp]] = {} in_degree: dict[UOp, int] = {} + var_vals: dict[str, int] = {} for u in sched_sink.toposort(): if u.op is Ops.RANGE: in_degree.setdefault(u, 0) @@ -46,7 +47,11 @@ def create_schedule(sched_sink:UOp) -> list[ScheduleItem]: elif s.op is Ops.BUFFER: pass # a BUFFER is already realized, nothing to do here elif s.op is Ops.BIND: - pass # BIND to RANGE handled in fixedvars, BIND to CONST extracted earlier in complete_create_schedule_with_vars + # for RANGE this is in fixedvars + if s.src[1].op is not Ops.RANGE: + var, val = s.unbind() + assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch 
on {var}, {var_vals[var.expr]} != {val}" + var_vals[var.expr] = val else: raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}") @@ -103,7 +108,7 @@ def create_schedule(sched_sink:UOp) -> list[ScheduleItem]: else: real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=())) sched_ptr += 1 - return real_schedule + return real_schedule, var_vals from tinygrad.engine.memory import memory_planner from tinygrad.schedule.rangeify import get_rangeify_map @@ -119,15 +124,11 @@ def replace_input_buffer(ctx:dict[UOp, UOp], b:UOp): ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx)))) return ret -def unbind_var(ctx:dict[UOp, UOp], b:UOp): - # tag the DEFINE_VAR to distinguish it from unbound ones in rebind - ctx[b] = ret = b.src[0].rtag(()) - return ret - pm_pre_sched_cache = PatternMatcher([ + # replace input buffers (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer), + # remove unique consts (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer), - (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), unbind_var), ]) def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp): @@ -140,7 +141,6 @@ def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp): pm_post_sched_cache = PatternMatcher([ (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back), (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back), - (UPat(Ops.DEFINE_VAR, name="b"), lambda ctx,b: ctx.get(b) if b.tag is not None else None), ]) schedule_cache: dict[bytes, tuple[UOp, UOp]] = {} @@ -149,16 +149,9 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li # big_sink srcs are all the Tensors st = time.perf_counter() - # replace all UNIQUE buffers with LUNIQUE, unbind BINDs + # replace all UNIQUE buffers with LUNIQUE input_buffers: dict[UOp, UOp] = {} big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache") - # extract var_vals from BINDs that were unbound (ctx stores BIND -> tagged_DEFINE_VAR) - var_vals: dict[str, int] = {} - for k in input_buffers: - if k.op is Ops.BIND: - name, val = k.src[0].arg[0], k.src[1].arg - assert name not in var_vals or var_vals[name] == val, f"bind mismatch on {k.src[0]}, {var_vals[name]} != {val}" - var_vals[name] = val sched_cache_key = big_sink_cache.key if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None: @@ -194,7 +187,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)} # create the schedule - schedule = create_schedule(big_sink) + schedule, var_vals = create_schedule_with_vars(big_sink) with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule) # remove all AFTERs, after scheduling, the tensors are just buffers diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index de48b67a32..69b74f0446 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field import itertools from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo -from tinygrad.uop.ops import graph_rewrite, 
identity_element, sint, AxisType, BottomUpGate, Kernel, range_str +from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str from tinygrad.uop.symbolic import symbolic from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY from tinygrad.helpers import PCONTIG, partition, get_single_element @@ -522,11 +522,6 @@ add_tags = PatternMatcher([ (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)), ]) -# NOTE: don't remove tag from DEFINE_VAR - schedule cache uses tags to track unbound variables, and ucache would return original untagged UOp -pm_remove_rangeify_tags = PatternMatcher([ - (UPat(GroupOp.All-{Ops.DEFINE_VAR}, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None), -]) - # support for using a contiguous permuted view instead of the parent view if one exists # modified from kernelize.py to not use ShapeTracker @@ -585,7 +580,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: # TODO: we can probably get this earlier sink_tags = [s.tag for s in tsink.src] - tsink = graph_rewrite(tsink, pm_remove_rangeify_tags, name="remove rangeify tags") + tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags") if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph") diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 44241b0d17..da7b6d15ca 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -251,7 +251,7 @@ class Tensor(OpMixin): def schedule(self, *lst:Tensor) -> list[ScheduleItem]: """Creates the schedule needed to realize these Tensor(s).""" schedule, var_vals = self.schedule_with_vars(*lst) - assert len(schedule) == 0 or len(var_vals) == 0 + assert len(var_vals) == 0 return schedule @disable_gc() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 5cc66bede4..2171c87b4f 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1312,6 +1312,7 @@ pm_lower_index_dtype = PatternMatcher([ def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0] _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) +_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)]) def do_unbind(ctx:dict[Variable, int], x:UOp): v,i = x.unbind()
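
For context, a minimal sketch of the flow this revert restores, based only on the API exercised by the deleted test above (`Variable`, `bind`, `schedule_with_vars`): with the unbind-caching removed, `var_vals` is collected inside `create_schedule_with_vars` again, so a bound variable still surfaces through `schedule_with_vars`, but two different bound values no longer share a schedule-cache entry. The values and expected results below mirror the removed `test_bound_variable_reuses_cache`.

```python
# Sketch only: assumes the tinygrad API used in the deleted test above.
from tinygrad import Tensor, Variable

v = Variable('v', 1, 100)                    # symbolic int constrained to [1, 100]
x = Tensor.ones(10).contiguous().realize()

t = (x + Tensor(v.bind(5))).sum()            # BIND carries the concrete value 5
schedule, var_vals = t.schedule_with_vars()  # vars now extracted during create_schedule_with_vars
print(var_vals)                              # expected: {'v': 5}
print(t.item())                              # expected: 60.0  (ones(10) + 5, summed)

# After this revert, a different bound value produces a different cache key,
# so the schedule is rebuilt rather than reused (the behavior #13662 had added).
t2 = (x + Tensor(v.bind(10))).sum()
print(t2.item())                             # expected: 110.0 (ones(10) + 10, summed)
```

The tradeoff, as far as the diff shows, is simplicity: `complete_create_schedule_with_vars` no longer needs the tagged-`DEFINE_VAR` unbind/rebind dance through `pm_pre_sched_cache`/`pm_post_sched_cache`, at the cost of one schedule-cache entry per bound value.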