add notes about jit to claude.md

George Hotz
2025-12-12 16:48:23 -05:00
parent af86cae10c
commit 429f82e6a9


@@ -34,6 +34,35 @@ result = graph_rewrite(uop, pm)
### Schedule Cache
Schedules are cached by graph structure. BIND nodes (variables with bound values) are unbound before cache key computation so different values hit the same cache.
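A minimal sketch of the unbinding idea (assumes tinygrad's `Variable.bind` / `UOp.unbind` API; the variable name is illustrative):
```python
from tinygrad import Variable

v = Variable("v", 1, 10)        # symbolic variable with range [1, 10]
b3, b7 = v.bind(3), v.bind(7)   # BIND UOps carrying concrete values
# unbind() strips the bound value, so both graphs produce the same cache key
assert b3.unbind()[0] is b7.unbind()[0]
```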
### TinyJit Behavior
TinyJit captures a schedule on the second call (cnt=1) and replays it on subsequent calls. **Critical**: The Python code inside a jitted function only runs during warmup (cnt=0,1). After that, only the captured schedule executes.
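A minimal sketch of the capture/replay behavior (function and shapes are illustrative):
```python
from tinygrad import Tensor, TinyJit

calls = 0
@TinyJit
def step(x: Tensor) -> Tensor:
    global calls
    calls += 1                # Python side effect: runs only during warmup
    return (x * 2).realize()  # jitted functions should realize their outputs

for _ in range(5):
    step(Tensor.randn(4).realize())
print(calls)  # 2 - only cnt=0 and cnt=1 ran the Python body
```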
**Side effects and assigns**: If a tensor is modified via `.assign()` inside a jitted function but not included in the `realize()` call, those assigns won't be captured in the schedule. This is especially important for:
- **BatchNorm running stats** (`running_mean`, `running_var`) - These are updated via `.assign()` during forward pass but are NOT dependencies of the loss
- Any stateful tensor updated as a side effect
```python
from tinygrad import Tensor, TinyJit
# model, x, params, buffers, grads are assumed defined in the enclosing scope

# ❌ BROKEN with JIT - BatchNorm stats only update during warmup (2 iterations)
@TinyJit
def train_step():
    loss = model(x).mean().backward()
    Tensor.realize(loss, *grads)  # running_mean.assign() not captured!

# ✅ CORRECT - explicitly realize buffers so assigns are in the schedule
@TinyJit
def train_step():
    loss = model(x).mean().backward()
    Tensor.realize(*params, *buffers, loss, *grads)  # buffers includes running stats
```
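The underlying reason: the captured schedule only contains kernels reachable from the tensors passed to `realize()`, so an `.assign()` whose target is never realized is simply not part of the replayed graph.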
**Debugging JIT issues**: If training works with `JIT=0` but fails with JIT, check whether all stateful tensors are being realized. You can verify that ASSIGN chains keep growing:
```python
# bn: a BatchNorm layer instance (assumed in scope)
def count_assign_chain(uop, depth=0):
    # follow src[0] while the ops are ASSIGN: depth is the number of pending writes
    if uop.op.name != 'ASSIGN': return depth
    return count_assign_chain(uop.src[0], depth+1)

print(count_assign_chain(bn.running_mean.uop))  # Should increase each step, not plateau at 2
```
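A depth that plateaus at 2 means only the two warmup iterations captured the update; comparing against a `JIT=0` run makes the divergence obvious.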
## Directory Structure
```
@@ -84,6 +113,7 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
- `SPEC=1` - Enable UOp spec verification
- `NOOPT=1` - Disable optimizations
- `DEVICE=CPU/CUDA/AMD/METAL` - Set default device
- `JIT=0` - Disable JIT (useful for debugging JIT-related issues)
## Debugging Tips
@@ -91,6 +121,8 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
2. **Check schedule**: `tensor.schedule()` returns list of ScheduleItems
3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks
4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]`
5. **JIT vs non-JIT**: If something works with `JIT=0` but not with JIT, the issue is likely unrealized side-effect tensors (see TinyJit Behavior above)
6. **Check tensor state**: `tensor.uop.op` shows current state - `Ops.BUFFER` means realized, `Ops.ASSIGN` means pending write
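A short sketch of tip 6 (the `Ops` import path is an assumption and may differ across versions):
```python
from tinygrad import Tensor
from tinygrad.uop.ops import Ops  # assumed import path for Ops

t = Tensor.ones(4) + 1
print(t.uop.op)                 # a lazy compute op: nothing realized yet
t.realize()
print(t.uop.op is Ops.BUFFER)   # True once realized (per tip 6)
```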
## Style Notes