From 429f82e6a933fd72459afad5823523adb6583eae Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 12 Dec 2025 16:48:23 -0500 Subject: [PATCH] add notes about jit to claude.md --- CLAUDE.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index 6d792002dc..7767a52146 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -34,6 +34,35 @@ result = graph_rewrite(uop, pm) ### Schedule Cache Schedules are cached by graph structure. BIND nodes (variables with bound values) are unbound before cache key computation so different values hit the same cache. +### TinyJit Behavior +TinyJit captures a schedule on the second call (cnt=1) and replays it on subsequent calls. **Critical**: The Python code inside a jitted function only runs during warmup (cnt=0,1). After that, only the captured schedule executes. + +**Side effects and assigns**: If a tensor is modified via `.assign()` inside a jitted function but not included in the `realize()` call, those assigns won't be captured in the schedule. This is especially important for: +- **BatchNorm running stats** (`running_mean`, `running_var`) - These are updated via `.assign()` during forward pass but are NOT dependencies of the loss +- Any stateful tensor updated as a side effect + +```python +# ❌ BROKEN with JIT - BatchNorm stats only update during warmup (2 iterations) +@TinyJit +def train_step(): + loss = model(x).mean().backward() + Tensor.realize(loss, grads) # running_mean.assign() not captured! + +# ✅ CORRECT - explicitly realize buffers so assigns are in the schedule +@TinyJit +def train_step(): + loss = model(x).mean().backward() + Tensor.realize(*params, *buffers, loss, grads) # buffers includes running stats +``` + +**Debugging JIT issues**: If training works with `JIT=0` but fails with JIT, check if stateful tensors are being realized. You can verify ASSIGN chains: +```python +def count_assign_chain(uop, depth=0): + if uop.op.name != 'ASSIGN': return depth + return count_assign_chain(uop.src[0], depth+1) +print(count_assign_chain(bn.running_mean.uop)) # Should increase each step, not plateau at 2 +``` + ## Directory Structure ``` @@ -84,6 +113,7 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()" - `SPEC=1` - Enable UOp spec verification - `NOOPT=1` - Disable optimizations - `DEVICE=CPU/CUDA/AMD/METAL` - Set default device +- `JIT=0` - Disable JIT (useful for debugging JIT-related issues) ## Debugging Tips @@ -91,6 +121,8 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()" 2. **Check schedule**: `tensor.schedule()` returns list of ScheduleItems 3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks 4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]` +5. **JIT vs non-JIT**: If something works with `JIT=0` but not with JIT, the issue is likely unrealized side-effect tensors (see TinyJit Behavior above) +6. **Check tensor state**: `tensor.uop.op` shows current state - `Ops.BUFFER` means realized, `Ops.ASSIGN` means pending write ## Style Notes