Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 23:18:04 -05:00)
update examples/gradaccum_mnist.py to use the JIT
--- a/examples/gradaccum_mnist.py
+++ b/examples/gradaccum_mnist.py
@@ -1,6 +1,6 @@
 import itertools
 from typing import Callable
-from tinygrad import nn, Tensor, dtypes, Device
+from tinygrad import nn, Tensor, dtypes, Device, TinyJit
 from tinygrad.helpers import getenv, trange, partition
 
 class Model:
@@ -28,59 +28,81 @@ def functional_adam(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor, lr=0
 
 if __name__ == "__main__":
   BS = getenv("BS", 512)
-  ACC_STEPS = getenv("ACC_STEPS", 4)
+  ACC_STEPS = getenv("ACC_STEPS", 8)
 
   X_train, Y_train, X_test, Y_test = nn.datasets.mnist()
   model = Model()
 
   params = nn.state.get_parameters(model)
 
-  # set requires grad on the ones we need gradients of
+  # init params, set requires grad on the ones we need gradients of
   for x in params:
     if x.requires_grad is None: x.requires_grad_()
+    x.replace(x.contiguous())
+  Tensor.realize(*params)
 
   # split params (with grads) and buffers (without)
-  params, buffers = partition(nn.state.get_parameters(model), lambda x: x.requires_grad)
+  params, buffers = partition(params, lambda x: x.requires_grad)
   print(f"params: {len(params)} buffers: {len(buffers)}")
 
   # optim params
   pos_params = list(itertools.accumulate(params, lambda x,y: x+y.numel(), initial=0))
   adam_m = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
   adam_v = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
-  adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False)
-  adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False)
+  adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
+  adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
   adam_params = [adam_m, adam_v, adam_b1_t, adam_b2_t]
 
+  # create loss and grads. init all state so the JIT works on microbatch
+  for x in params: x.assign(x.detach())
+  loss = Tensor.zeros(tuple()).contiguous()
+  grads = Tensor.zeros(pos_params[-1]).contiguous()
+  Tensor.realize(*params, *buffers, *adam_params, loss, grads)
+
+  @TinyJit
   @Tensor.train()
-  def microbatch(loss: Tensor, grads: Tensor):
+  def microbatch():
     samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
     for t in params: t.grad = None
     # divide by ACC_STEPS at the loss
     uloss = (model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]) / ACC_STEPS).backward()
+    ugrads = Tensor.cat(*[t.grad.contiguous().flatten() for t in params], dim=0)
+    for t in params: t.grad = None
     # concat the grads and assign them
     loss.assign(loss + uloss)
-    grads.assign(grads + Tensor.cat(*[t.grad.contiguous().flatten() for t in params], dim=0))
-    Tensor.realize(loss, grads)
+    grads.assign(grads + ugrads)
+    Tensor.realize(*params, *buffers, loss, grads)
 
+  @TinyJit
+  def optimizer():
+    # run optimizer (on CPU, where adam params live)
+    delta = functional_adam(grads.to("CPU"), adam_m, adam_v, adam_b1_t, adam_b2_t)
+
+    # update the params, copying back the delta one at a time to avoid OOM
+    # NOTE: the scheduler is ordering things poorly, all the copies are happening before the adds
+    for j,tt in enumerate(params):
+      tt.assign(tt.detach() - delta[pos_params[j]:pos_params[j+1]].reshape(tt.shape).to(Device.DEFAULT))
+
+    # realize everything, zero out loss and grads
+    loss.assign(Tensor.zeros_like(loss))
+    grads.assign(Tensor.zeros_like(grads))
+    Tensor.realize(*params, *adam_params, loss, grads)
+
+  @TinyJit
   def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
 
   test_acc = float('nan')
   for i in (t:=trange(getenv("STEPS", 70))):
     # microbatch sets the gradients
-    loss = Tensor.zeros(tuple()).contiguous()
-    grads = Tensor.zeros(pos_params[-1]).contiguous()
-    for _ in range(ACC_STEPS): microbatch(loss, grads)
+    for _ in range(ACC_STEPS): microbatch()
 
-    # run optimizer (on CPU, where adam params live)
-    delta = functional_adam(grads.to("CPU"), adam_m, adam_v, adam_b1_t, adam_b2_t)
+    # get the loss before the optimizer clears it
+    # this is already realized so this isn't a schedule
+    loss_item = loss.item()
 
-    # update the params, copying back the delta one at a time to avoid OOM
-    for j,tt in enumerate(params):
-      tt.assign(tt.detach() - delta[pos_params[j]:pos_params[j+1]].reshape(tt.shape).to(Device.DEFAULT))
-
-    # realize everything
-    Tensor.realize(*params, *buffers, *adam_params)
+    # run the optimizer
+    optimizer()
 
     # eval
     if i%10 == 9: test_acc = get_test_acc().item()
-    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
+    t.set_description(f"loss: {loss_item:6.2f} test_accuracy: {test_acc:5.2f}%")
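
The point of the change is to make the microbatch and optimizer steps capturable by TinyJit: every tensor they mutate (params, loss, grads, and the Adam state) is made contiguous and realized once before the first call, the JITted functions then update that state in place with assign, and each one ends with a Tensor.realize so the kernels for those updates are part of the capture. A minimal sketch of that pattern, with a made-up counter tensor and step function that are not part of the example:

from tinygrad import Tensor, TinyJit

# pre-allocate the state the JITted function will mutate, and realize it before capture
counter = Tensor.zeros(1).contiguous()
Tensor.realize(counter)

@TinyJit
def step():
  counter.assign(counter + 1)   # in-place update of the captured buffer
  Tensor.realize(counter)       # realize inside the function so the JIT records these kernels

for _ in range(5): step()
print(counter.item())           # 5.0 once the warmup calls and the replays have all run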
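
functional_adam itself is unchanged by this commit and only its signature shows up in the hunk header. For orientation, a hedged sketch of what a functional Adam step of that shape typically looks like; the hyperparameter defaults, the update order, and the name functional_adam_sketch are assumptions rather than the file's actual code:

from tinygrad import Tensor

def functional_adam_sketch(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor,
                           lr=0.001, b1=0.9, b2=0.999, eps=1e-8) -> Tensor:
  b1_t.assign(b1_t * b1)                 # running beta1^t for bias correction
  b2_t.assign(b2_t * b2)                 # running beta2^t for bias correction
  m.assign(b1 * m + (1 - b1) * g)        # first-moment (mean) estimate
  v.assign(b2 * v + (1 - b2) * g * g)    # second-moment (uncentered variance) estimate
  m_hat = m / (1 - b1_t)                 # bias-corrected moments
  v_hat = v / (1 - b2_t)
  return lr * m_hat / (v_hat.sqrt() + eps)  # delta that the caller subtracts from the params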
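
The pos_params offsets are what let one flat buffer stand in for all per-parameter gradients: itertools.accumulate builds running element counts, the per-tensor grads are flattened and concatenated into a single vector, and slices of the returned delta are reshaped back onto each parameter before the copy to Device.DEFAULT. A small self-contained illustration of that bookkeeping, using toy shapes that are not the MNIST model's:

import itertools
from tinygrad import Tensor

params = [Tensor.ones(3, 4), Tensor.ones(10)]       # toy parameters, not the model's
pos_params = list(itertools.accumulate(params, lambda x,y: x+y.numel(), initial=0))
print(pos_params)                                   # [0, 12, 22]: running offsets into the flat vector

flat = Tensor.cat(*[p.flatten() for p in params], dim=0)   # one 22-element buffer
assert flat.shape == (22,)

# slicing an offset range and reshaping recovers each parameter-shaped block
chunks = [flat[pos_params[j]:pos_params[j+1]].reshape(p.shape) for j,p in enumerate(params)]
assert [c.shape for c in chunks] == [(3, 4), (10,)]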