update examples/gradaccum_mnist.py to use the JIT

George Hotz
2025-12-03 16:11:42 -08:00
parent 24ca8eeaa7
commit 96d16675fe

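The change below moves the example onto TinyJit. The contract it codes to: every tensor the jitted functions touch is made contiguous and realized before the first call, the functions take no tensor arguments, and they mutate that captured state in place with assign, so the replayed kernels keep hitting the same buffers. A rough sketch of that shape (not part of this commit; the counter tensor acc is made up for illustration):

from tinygrad import Tensor, TinyJit

# state the JIT will capture: allocated and realized up front, like loss/grads in the diff
acc = Tensor.zeros(10).contiguous().realize()

@TinyJit
def step():
  # update the captured buffer in place instead of rebinding the name
  acc.assign(acc + 1)
  Tensor.realize(acc)

for _ in range(4): step()  # early calls run eagerly and capture, later calls replay
print(acc.tolist())        # expect all 4.0s

That is also why loss and grads are now allocated and realized once up front instead of being rebuilt every step and passed in as arguments, matching the new comment "init all state so the JIT works on microbatch".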

@@ -1,6 +1,6 @@
 import itertools
 from typing import Callable
-from tinygrad import nn, Tensor, dtypes, Device
+from tinygrad import nn, Tensor, dtypes, Device, TinyJit
 from tinygrad.helpers import getenv, trange, partition
 class Model:
@@ -28,59 +28,81 @@ def functional_adam(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor, lr=0
 if __name__ == "__main__":
   BS = getenv("BS", 512)
-  ACC_STEPS = getenv("ACC_STEPS", 4)
+  ACC_STEPS = getenv("ACC_STEPS", 8)
   X_train, Y_train, X_test, Y_test = nn.datasets.mnist()
   model = Model()
   params = nn.state.get_parameters(model)
-  # set requires grad on the ones we need gradients of
+  # init params, set requires grad on the ones we need gradients of
   for x in params:
     if x.requires_grad is None: x.requires_grad_()
+    x.replace(x.contiguous())
+  Tensor.realize(*params)
   # split params (with grads) and buffers (without)
-  params, buffers = partition(nn.state.get_parameters(model), lambda x: x.requires_grad)
+  params, buffers = partition(params, lambda x: x.requires_grad)
   print(f"params: {len(params)} buffers: {len(buffers)}")
   # optim params
   pos_params = list(itertools.accumulate(params, lambda x,y: x+y.numel(), initial=0))
   adam_m = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
   adam_v = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
-  adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False)
-  adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False)
+  adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
+  adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
   adam_params = [adam_m, adam_v, adam_b1_t, adam_b2_t]
+  # create loss and grads. init all state so the JIT works on microbatch
+  for x in params: x.assign(x.detach())
+  loss = Tensor.zeros(tuple()).contiguous()
+  grads = Tensor.zeros(pos_params[-1]).contiguous()
+  Tensor.realize(*params, *buffers, *adam_params, loss, grads)
+  @TinyJit
   @Tensor.train()
-  def microbatch(loss: Tensor, grads: Tensor):
+  def microbatch():
     samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
     for t in params: t.grad = None
     # divide by ACC_STEPS at the loss
     uloss = (model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]) / ACC_STEPS).backward()
+    ugrads = Tensor.cat(*[t.grad.contiguous().flatten() for t in params], dim=0)
+    for t in params: t.grad = None
     # concat the grads and assign them
     loss.assign(loss + uloss)
-    grads.assign(grads + Tensor.cat(*[t.grad.contiguous().flatten() for t in params], dim=0))
-    Tensor.realize(loss, grads)
+    grads.assign(grads + ugrads)
+    Tensor.realize(*params, *buffers, loss, grads)
+  @TinyJit
+  def optimizer():
+    # run optimizer (on CPU, where adam params live)
+    delta = functional_adam(grads.to("CPU"), adam_m, adam_v, adam_b1_t, adam_b2_t)
+    # update the params, copying back the delta one at a time to avoid OOM
+    # NOTE: the scheduler is ordering things poorly, all the copies are happening before the adds
+    for j,tt in enumerate(params):
+      tt.assign(tt.detach() - delta[pos_params[j]:pos_params[j+1]].reshape(tt.shape).to(Device.DEFAULT))
+    # realize everything, zero out loss and grads
+    loss.assign(Tensor.zeros_like(loss))
+    grads.assign(Tensor.zeros_like(grads))
+    Tensor.realize(*params, *adam_params, loss, grads)
+  @TinyJit
   def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
   test_acc = float('nan')
   for i in (t:=trange(getenv("STEPS", 70))):
     # microbatch sets the gradients
-    loss = Tensor.zeros(tuple()).contiguous()
-    grads = Tensor.zeros(pos_params[-1]).contiguous()
-    for _ in range(ACC_STEPS): microbatch(loss, grads)
-    # run optimizer (on CPU, where adam params live)
-    delta = functional_adam(grads.to("CPU"), adam_m, adam_v, adam_b1_t, adam_b2_t)
-    # update the params, copying back the delta one at a time to avoid OOM
-    for j,tt in enumerate(params):
-      tt.assign(tt.detach() - delta[pos_params[j]:pos_params[j+1]].reshape(tt.shape).to(Device.DEFAULT))
-    # realize everything
-    Tensor.realize(*params, *buffers, *adam_params)
+    for _ in range(ACC_STEPS): microbatch()
+    # get the loss before the optimizer clears it
+    # this is already realized so this isn't a schedule
+    loss_item = loss.item()
+    # run the optimizer
+    optimizer()
     # eval
     if i%10 == 9: test_acc = get_test_acc().item()
-    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
+    t.set_description(f"loss: {loss_item:6.2f} test_accuracy: {test_acc:5.2f}%")