From 44104b0b7f3f915c0a3a9ce4d2b78758d33b9a38 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 1 Dec 2025 18:27:32 -0800
Subject: [PATCH] mnist with grad acc + Adam on CPU (#13520)

* mnist with grad acc + Adam on CPU

* still broken, but closer

* works w/o jit

* this works without the jit
---
 examples/gradaccum_mnist.py | 73 +++++++++++++++++++++++++++++++++++++
 1 file changed, 73 insertions(+)
 create mode 100644 examples/gradaccum_mnist.py

diff --git a/examples/gradaccum_mnist.py b/examples/gradaccum_mnist.py
new file mode 100644
index 0000000000..c46cd0ced3
--- /dev/null
+++ b/examples/gradaccum_mnist.py
@@ -0,0 +1,73 @@
+import itertools
+from examples.beautiful_mnist import Model
+from tinygrad import nn, Tensor, dtypes, Device
+from tinygrad.helpers import getenv, trange, partition
+
+# TODO: refactor this into optim/onnx
+def functional_adam(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor, lr=0.001, b1=0.9, b2=0.999, eps=1e-6) -> Tensor:
+  b1_t *= b1
+  b2_t *= b2
+  m.assign(b1 * m + (1.0 - b1) * g)
+  v.assign(b2 * v + (1.0 - b2) * (g * g))
+  m_hat = m / (1.0 - b1_t)
+  v_hat = v / (1.0 - b2_t)
+  return lr * (m_hat / (v_hat.sqrt() + eps))
+
+if __name__ == "__main__":
+  BS = getenv("BS", 512)
+  ACC_STEPS = getenv("ACC_STEPS", 4)
+
+  X_train, Y_train, X_test, Y_test = nn.datasets.mnist()
+  model = Model()
+
+  params = nn.state.get_parameters(model)
+
+  # set requires grad on the ones we need gradients of
+  for x in params:
+    if x.requires_grad is None: x.requires_grad_()
+
+  # split params (with grads) and buffers (without)
+  params, buffers = partition(nn.state.get_parameters(model), lambda x: x.requires_grad)
+  print(f"params: {len(params)} buffers: {len(buffers)}")
+
+  # optim params
+  pos_params = list(itertools.accumulate(params, lambda x,y: x+y.numel(), initial=0))
+  adam_m = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
+  adam_v = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
+  adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False)
+  adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False)
+  adam_params = [adam_m, adam_v, adam_b1_t, adam_b2_t]
+
+  @Tensor.train()
+  def microbatch(loss: Tensor, grads: Tensor):
+    samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
+    for t in params: t.grad = None
+    # divide by ACC_STEPS at the loss
+    uloss = (model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]) / ACC_STEPS).backward()
+    # concat the grads and assign them
+    loss.assign(loss + uloss)
+    grads.assign(grads + Tensor.cat(*[t.grad.contiguous().flatten() for t in params], dim=0))
+    Tensor.realize(loss, grads)
+
+  def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
+
+  test_acc = float('nan')
+  for i in (t:=trange(getenv("STEPS", 70))):
+    # microbatch sets the gradients
+    loss = Tensor.zeros(tuple()).contiguous()
+    grads = Tensor.zeros(pos_params[-1]).contiguous()
+    for _ in range(ACC_STEPS): microbatch(loss, grads)
+
+    # run optimizer (on CPU, where adam params live)
+    delta = functional_adam(grads.to("CPU"), adam_m, adam_v, adam_b1_t, adam_b2_t)
+
+    # update the params, copying back the delta one at a time to avoid OOM
+    for j,tt in enumerate(params):
+      tt.assign(tt.detach() - delta[pos_params[j]:pos_params[j+1]].reshape(tt.shape).to(Device.DEFAULT))
+
+    # realize everything
+    Tensor.realize(*params, *buffers, *adam_params)
+
+    # eval
+    if i%10 == 9: test_acc = get_test_acc().item()
+    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
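
Usage note (not part of the patch): the script reads its knobs from the environment via getenv (BS, ACC_STEPS, STEPS), and since it imports from the examples package it is easiest to run as a module from a tinygrad checkout. Something like the following should work; the exact invocation is an assumption, the patch itself only defines the getenv defaults.

  BS=512 ACC_STEPS=4 STEPS=70 python3 -m examples.gradaccum_mnist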