mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
update examples/gradaccum_mnist.py to use the JIT
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
from typing import Callable
|
||||
from tinygrad import nn, Tensor, dtypes, Device
|
||||
from tinygrad import nn, Tensor, dtypes, Device, TinyJit
|
||||
from tinygrad.helpers import getenv, trange, partition
|
||||
|
||||
class Model:
|
||||
@@ -28,59 +28,81 @@ def functional_adam(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor, lr=0
|
||||
|
||||
if __name__ == "__main__":
|
||||
BS = getenv("BS", 512)
|
||||
ACC_STEPS = getenv("ACC_STEPS", 4)
|
||||
ACC_STEPS = getenv("ACC_STEPS", 8)
|
||||
|
||||
X_train, Y_train, X_test, Y_test = nn.datasets.mnist()
|
||||
model = Model()
|
||||
|
||||
params = nn.state.get_parameters(model)
|
||||
|
||||
# set requires grad on the ones we need gradients of
|
||||
# init params, set requires grad on the ones we need gradients of
|
||||
for x in params:
|
||||
if x.requires_grad is None: x.requires_grad_()
|
||||
x.replace(x.contiguous())
|
||||
Tensor.realize(*params)
|
||||
|
||||
# split params (with grads) and buffers (without)
|
||||
params, buffers = partition(nn.state.get_parameters(model), lambda x: x.requires_grad)
|
||||
params, buffers = partition(params, lambda x: x.requires_grad)
|
||||
print(f"params: {len(params)} buffers: {len(buffers)}")
|
||||
|
||||
# optim params
|
||||
pos_params = list(itertools.accumulate(params, lambda x,y: x+y.numel(), initial=0))
|
||||
adam_m = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
|
||||
adam_v = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
|
||||
adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False)
|
||||
adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False)
|
||||
adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
|
||||
adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
|
||||
adam_params = [adam_m, adam_v, adam_b1_t, adam_b2_t]
|
||||
|
||||
# create loss and grads. init all state so the JIT works on microbatch
|
||||
for x in params: x.assign(x.detach())
|
||||
loss = Tensor.zeros(tuple()).contiguous()
|
||||
grads = Tensor.zeros(pos_params[-1]).contiguous()
|
||||
Tensor.realize(*params, *buffers, *adam_params, loss, grads)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
def microbatch(loss: Tensor, grads: Tensor):
|
||||
def microbatch():
|
||||
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
|
||||
for t in params: t.grad = None
|
||||
# divide by ACC_STEPS at the loss
|
||||
uloss = (model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]) / ACC_STEPS).backward()
|
||||
ugrads = Tensor.cat(*[t.grad.contiguous().flatten() for t in params], dim=0)
|
||||
for t in params: t.grad = None
|
||||
# concat the grads and assign them
|
||||
loss.assign(loss + uloss)
|
||||
grads.assign(grads + Tensor.cat(*[t.grad.contiguous().flatten() for t in params], dim=0))
|
||||
Tensor.realize(loss, grads)
|
||||
grads.assign(grads + ugrads)
|
||||
Tensor.realize(*params, *buffers, loss, grads)
|
||||
|
||||
@TinyJit
|
||||
def optimizer():
|
||||
# run optimizer (on CPU, where adam params live)
|
||||
delta = functional_adam(grads.to("CPU"), adam_m, adam_v, adam_b1_t, adam_b2_t)
|
||||
|
||||
# update the params, copying back the delta one at a time to avoid OOM
|
||||
# NOTE: the scheduler is ordering things poorly, all the copies are happening before the adds
|
||||
for j,tt in enumerate(params):
|
||||
tt.assign(tt.detach() - delta[pos_params[j]:pos_params[j+1]].reshape(tt.shape).to(Device.DEFAULT))
|
||||
|
||||
# realize everything, zero out loss and grads
|
||||
loss.assign(Tensor.zeros_like(loss))
|
||||
grads.assign(Tensor.zeros_like(grads))
|
||||
Tensor.realize(*params, *adam_params, loss, grads)
|
||||
|
||||
@TinyJit
|
||||
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
|
||||
|
||||
test_acc = float('nan')
|
||||
for i in (t:=trange(getenv("STEPS", 70))):
|
||||
# microbatch sets the gradients
|
||||
loss = Tensor.zeros(tuple()).contiguous()
|
||||
grads = Tensor.zeros(pos_params[-1]).contiguous()
|
||||
for _ in range(ACC_STEPS): microbatch(loss, grads)
|
||||
for _ in range(ACC_STEPS): microbatch()
|
||||
|
||||
# run optimizer (on CPU, where adam params live)
|
||||
delta = functional_adam(grads.to("CPU"), adam_m, adam_v, adam_b1_t, adam_b2_t)
|
||||
# get the loss before the optimizer clears it
|
||||
# this is already realized so this isn't a schedule
|
||||
loss_item = loss.item()
|
||||
|
||||
# update the params, copying back the delta one at a time to avoid OOM
|
||||
for j,tt in enumerate(params):
|
||||
tt.assign(tt.detach() - delta[pos_params[j]:pos_params[j+1]].reshape(tt.shape).to(Device.DEFAULT))
|
||||
|
||||
# realize everything
|
||||
Tensor.realize(*params, *buffers, *adam_params)
|
||||
# run the optimizer
|
||||
optimizer()
|
||||
|
||||
# eval
|
||||
if i%10 == 9: test_acc = get_test_acc().item()
|
||||
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
|
||||
t.set_description(f"loss: {loss_item:6.2f} test_accuracy: {test_acc:5.2f}%")
|
||||
|
||||
Reference in New Issue
Block a user