From 96d16675fe0861e87565d15572e2b0077f9e36e3 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Wed, 3 Dec 2025 16:11:42 -0800
Subject: [PATCH] update examples/gradaccum_mnist.py to use the JIT

---
 examples/gradaccum_mnist.py | 64 +++++++++++++++++++++++++------------
 1 file changed, 43 insertions(+), 21 deletions(-)

diff --git a/examples/gradaccum_mnist.py b/examples/gradaccum_mnist.py
index 376fc5785a..a660afddf4 100644
--- a/examples/gradaccum_mnist.py
+++ b/examples/gradaccum_mnist.py
@@ -1,6 +1,6 @@
 import itertools
 from typing import Callable
-from tinygrad import nn, Tensor, dtypes, Device
+from tinygrad import nn, Tensor, dtypes, Device, TinyJit
 from tinygrad.helpers import getenv, trange, partition
 
 class Model:
@@ -28,59 +28,81 @@ def functional_adam(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor, lr=0
 
 if __name__ == "__main__":
   BS = getenv("BS", 512)
-  ACC_STEPS = getenv("ACC_STEPS", 4)
+  ACC_STEPS = getenv("ACC_STEPS", 8)
   X_train, Y_train, X_test, Y_test = nn.datasets.mnist()
 
   model = Model()
   params = nn.state.get_parameters(model)
 
-  # set requires grad on the ones we need gradients of
+  # init params, set requires grad on the ones we need gradients of
   for x in params:
     if x.requires_grad is None: x.requires_grad_()
+    x.replace(x.contiguous())
+  Tensor.realize(*params)
 
   # split params (with grads) and buffers (without)
-  params, buffers = partition(nn.state.get_parameters(model), lambda x: x.requires_grad)
+  params, buffers = partition(params, lambda x: x.requires_grad)
   print(f"params: {len(params)} buffers: {len(buffers)}")
 
   # optim params
   pos_params = list(itertools.accumulate(params, lambda x,y: x+y.numel(), initial=0))
   adam_m = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
   adam_v = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
-  adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False)
-  adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False)
+  adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
+  adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
   adam_params = [adam_m, adam_v, adam_b1_t, adam_b2_t]
 
+  # create loss and grads. init all state so the JIT works on microbatch
+  for x in params: x.assign(x.detach())
+  loss = Tensor.zeros(tuple()).contiguous()
+  grads = Tensor.zeros(pos_params[-1]).contiguous()
+  Tensor.realize(*params, *buffers, *adam_params, loss, grads)
+
+  @TinyJit
   @Tensor.train()
-  def microbatch(loss: Tensor, grads: Tensor):
+  def microbatch():
     samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
     for t in params: t.grad = None
     # divide by ACC_STEPS at the loss
     uloss = (model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]) / ACC_STEPS).backward()
+    ugrads = Tensor.cat(*[t.grad.contiguous().flatten() for t in params], dim=0)
+    for t in params: t.grad = None
 
     # concat the grads and assign them
     loss.assign(loss + uloss)
-    grads.assign(grads + Tensor.cat(*[t.grad.contiguous().flatten() for t in params], dim=0))
-    Tensor.realize(loss, grads)
+    grads.assign(grads + ugrads)
+    Tensor.realize(*params, *buffers, loss, grads)
 
+  @TinyJit
+  def optimizer():
+    # run optimizer (on CPU, where adam params live)
+    delta = functional_adam(grads.to("CPU"), adam_m, adam_v, adam_b1_t, adam_b2_t)
+
+    # update the params, copying back the delta one at a time to avoid OOM
+    # NOTE: the scheduler is ordering things poorly, all the copies are happening before the adds
+    for j,tt in enumerate(params):
+      tt.assign(tt.detach() - delta[pos_params[j]:pos_params[j+1]].reshape(tt.shape).to(Device.DEFAULT))
+
+    # realize everything, zero out loss and grads
+    loss.assign(Tensor.zeros_like(loss))
+    grads.assign(Tensor.zeros_like(grads))
+    Tensor.realize(*params, *adam_params, loss, grads)
+
+  @TinyJit
   def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
 
   test_acc = float('nan')
   for i in (t:=trange(getenv("STEPS", 70))):
     # microbatch sets the gradients
-    loss = Tensor.zeros(tuple()).contiguous()
-    grads = Tensor.zeros(pos_params[-1]).contiguous()
-    for _ in range(ACC_STEPS): microbatch(loss, grads)
+    for _ in range(ACC_STEPS): microbatch()
 
-    # run optimizer (on CPU, where adam params live)
-    delta = functional_adam(grads.to("CPU"), adam_m, adam_v, adam_b1_t, adam_b2_t)
+    # get the loss before the optimizer clears it
+    # this is already realized so this isn't a schedule
+    loss_item = loss.item()
 
-    # update the params, copying back the delta one at a time to avoid OOM
-    for j,tt in enumerate(params):
-      tt.assign(tt.detach() - delta[pos_params[j]:pos_params[j+1]].reshape(tt.shape).to(Device.DEFAULT))
-
-    # realize everything
-    Tensor.realize(*params, *buffers, *adam_params)
+    # run the optimizer
+    optimizer()
 
     # eval
     if i%10 == 9: test_acc = get_test_acc().item()
-    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
+    t.set_description(f"loss: {loss_item:6.2f} test_accuracy: {test_acc:5.2f}%")
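
The new microbatch() and optimizer() functions follow a common TinyJit pattern: allocate and realize every tensor the jitted function touches before the first call, take no arguments, and write results back with .assign() into those pre-realized buffers so later calls can replay the captured kernels. Below is a minimal sketch of that pattern, not part of the patch; the names acc and step are purely illustrative.

  from tinygrad import Tensor, TinyJit

  # pre-allocate and realize the state the jitted function will update in place
  acc = Tensor.zeros(10).contiguous()
  Tensor.realize(acc)

  @TinyJit
  def step():
    # no arguments: everything step reads or writes already exists and is realized
    acc.assign(acc + Tensor.ones(10))
    Tensor.realize(acc)

  # the first couple of calls trace and capture the kernels, later calls replay them
  for _ in range(4): step()
  print(acc.tolist())  # expected: [4.0, 4.0, ..., 4.0]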
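The optimizer calls functional_adam(grads, adam_m, adam_v, adam_b1_t, adam_b2_t), whose body lies outside these hunks. For reference, here is a generic flat-tensor Adam step with the same signature shape; it is only a sketch under assumed lr/b1/b2/eps defaults, not the file's actual implementation.

  from tinygrad import Tensor

  # hypothetical stand-in for the functional_adam defined earlier in the file
  def functional_adam_sketch(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor,
                             lr=0.001, b1=0.9, b2=0.999, eps=1e-8) -> Tensor:
    # update the moments and the running beta powers in place
    m.assign(b1*m + (1-b1)*g)
    v.assign(b2*v + (1-b2)*(g*g))
    b1_t.assign(b1_t*b1)
    b2_t.assign(b2_t*b2)
    # bias-corrected step, returned as the flat delta the main loop subtracts from the params
    m_hat = m / (1 - b1_t)
    v_hat = v / (1 - b2_t)
    return lr * m_hat / (v_hat.sqrt() + eps)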