add notes about jit to claude.md

This commit is contained in:
George Hotz
2025-12-12 16:48:23 -05:00
parent af86cae10c
commit 429f82e6a9

View File

@@ -34,6 +34,35 @@ result = graph_rewrite(uop, pm)
### Schedule Cache
Schedules are cached by graph structure. BIND nodes (variables with bound values) are unbound before cache key computation so different values hit the same cache.
### TinyJit Behavior
TinyJit captures a schedule on the second call (cnt=1) and replays it on subsequent calls. **Critical**: The Python code inside a jitted function only runs during warmup (cnt=0,1). After that, only the captured schedule executes.
**Side effects and assigns**: If a tensor is modified via `.assign()` inside a jitted function but not included in the `realize()` call, those assigns won't be captured in the schedule. This is especially important for:
- **BatchNorm running stats** (`running_mean`, `running_var`) - These are updated via `.assign()` during forward pass but are NOT dependencies of the loss
- Any stateful tensor updated as a side effect
```python
# ❌ BROKEN with JIT - BatchNorm stats only update during warmup (2 iterations)
@TinyJit
def train_step():
loss = model(x).mean().backward()
Tensor.realize(loss, grads) # running_mean.assign() not captured!
# ✅ CORRECT - explicitly realize buffers so assigns are in the schedule
@TinyJit
def train_step():
loss = model(x).mean().backward()
Tensor.realize(*params, *buffers, loss, grads) # buffers includes running stats
```
**Debugging JIT issues**: If training works with `JIT=0` but fails with JIT, check if stateful tensors are being realized. You can verify ASSIGN chains:
```python
def count_assign_chain(uop, depth=0):
if uop.op.name != 'ASSIGN': return depth
return count_assign_chain(uop.src[0], depth+1)
print(count_assign_chain(bn.running_mean.uop)) # Should increase each step, not plateau at 2
```
## Directory Structure
```
@@ -84,6 +113,7 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
- `SPEC=1` - Enable UOp spec verification
- `NOOPT=1` - Disable optimizations
- `DEVICE=CPU/CUDA/AMD/METAL` - Set default device
- `JIT=0` - Disable JIT (useful for debugging JIT-related issues)
## Debugging Tips
@@ -91,6 +121,8 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
2. **Check schedule**: `tensor.schedule()` returns list of ScheduleItems
3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks
4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]`
5. **JIT vs non-JIT**: If something works with `JIT=0` but not with JIT, the issue is likely unrealized side-effect tensors (see TinyJit Behavior above)
6. **Check tensor state**: `tensor.uop.op` shows current state - `Ops.BUFFER` means realized, `Ops.ASSIGN` means pending write
## Style Notes