mirror of https://github.com/tinygrad/tinygrad.git
add notes about jit to claude.md
CLAUDE.md (32 lines changed)
@@ -34,6 +34,35 @@ result = graph_rewrite(uop, pm)
### Schedule Cache

Schedules are cached by graph structure. BIND nodes (variables with bound values) are unbound before the cache key is computed, so different bound values hit the same cache entry.
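A minimal sketch of what this buys (assuming `Variable` is importable from the top-level `tinygrad` package and that symbolic slicing is available; both vary by version): two runs that differ only in the bound value share one cache entry.

```python
from tinygrad import Tensor, Variable  # assumption: Variable's import path differs in older versions

# same graph structure, two different bound values -> one schedule-cache entry after unbinding
v = Variable("n", 1, 16)
x = Tensor.ones(4, 16).contiguous().realize()

a = x[:, :v.bind(3)].sum(axis=1).realize()   # builds the schedule and populates the cache
b = x[:, :v.bind(7)].sum(axis=1).realize()   # different value, same unbound graph -> cache hit
```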
### TinyJit Behavior

TinyJit captures a schedule on the second call (cnt=1) and replays it on subsequent calls. **Critical**: The Python code inside a jitted function only runs during warmup (cnt=0,1). After that, only the captured schedule executes.
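A quick way to see this: a `print` in the function body fires only on the two warmup calls. A minimal sketch (the function and shapes are made up for illustration):

```python
from tinygrad import Tensor, TinyJit

@TinyJit
def double(x: Tensor) -> Tensor:
    print("python body ran")      # visible only while cnt=0,1 (warmup)
    return (x * 2).realize()

for i in range(5):
    double(Tensor([float(i)]))    # prints twice, then only the captured schedule replays
```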
**Side effects and assigns**: If a tensor is modified via `.assign()` inside a jitted function but not included in the `realize()` call, those assigns won't be captured in the schedule. This is especially important for:
- **BatchNorm running stats** (`running_mean`, `running_var`) - These are updated via `.assign()` during the forward pass but are NOT dependencies of the loss
- Any stateful tensor updated as a side effect
```python
# assumed context (defined elsewhere): model, x, grads, params, and buffers (buffers include the BatchNorm running stats)

# ❌ BROKEN with JIT - BatchNorm stats only update during warmup (2 iterations)
@TinyJit
def train_step():
    loss = model(x).mean().backward()
    Tensor.realize(loss, grads)  # running_mean.assign() not captured!

# ✅ CORRECT - explicitly realize buffers so assigns are in the schedule
@TinyJit
def train_step():
    loss = model(x).mean().backward()
    Tensor.realize(*params, *buffers, loss, grads)  # buffers includes running stats
```
**Debugging JIT issues**: If training works with `JIT=0` but fails with JIT, check if stateful tensors are being realized. You can verify ASSIGN chains:
```python
def count_assign_chain(uop, depth=0):
    # follow src[0] while the op is ASSIGN; depth counts the pending assigns in the chain
    if uop.op.name != 'ASSIGN': return depth
    return count_assign_chain(uop.src[0], depth+1)

print(count_assign_chain(bn.running_mean.uop))  # Should increase each step, not plateau at 2
```
## Directory Structure
```
@@ -84,6 +113,7 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
- `SPEC=1` - Enable UOp spec verification
- `NOOPT=1` - Disable optimizations
- `DEVICE=CPU/CUDA/AMD/METAL` - Set default device
- `JIT=0` - Disable JIT (useful for debugging JIT-related issues)
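These are ordinary environment variables set on the command line, e.g. combining two of the flags above with the same one-liner used earlier in this file:

```
JIT=0 DEVICE=CPU python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
```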
## Debugging Tips
@@ -91,6 +121,8 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
2. **Check schedule**: `tensor.schedule()` returns a list of ScheduleItems
3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks
4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]`
5. **JIT vs non-JIT**: If something works with `JIT=0` but not with JIT, the issue is likely unrealized side-effect tensors (see TinyJit Behavior above)
6. **Check tensor state**: `tensor.uop.op` shows current state - `Ops.BUFFER` means realized, `Ops.ASSIGN` means pending write (see the sketch below)
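A small sketch tying tips 2, 4, and 6 together (assumes a recent tinygrad where `Ops` lives in `tinygrad.uop.ops`; older versions expose it elsewhere):

```python
from tinygrad import Tensor
from tinygrad.uop.ops import Ops  # assumption: import path differs in older versions

t = Tensor.ones(4).sum()
print(t.uop.op)                           # tip 6: pending compute, not yet Ops.BUFFER
print([u.op for u in t.uop.toposort()])   # tip 4: inspect the op of every UOp in the graph
t.realize()
print(t.uop.op)                           # tip 6: Ops.BUFFER once realized

s = Tensor.ones(4).sum()
print(s.schedule())                       # tip 2: the ScheduleItems that realize() would run
```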
## Style Notes