diff --git a/docs/quickstart.md b/docs/quickstart.md index 17a3b44725..b9887c2154 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -131,12 +131,6 @@ Training neural networks in tinygrad is super simple. All we need to do is define our neural network, define our loss function, and then call `.backward()` on the loss function to compute the gradients. They can then be used to update the parameters of our neural network using one of the many optimizers in [optim.py](/tinygrad/nn/optim.py). -First we need to set the training flag in `Tensor`: - -```python -Tensor.training = True -``` - For our loss function we will be using sparse categorical cross entropy loss. ```python @@ -176,37 +170,41 @@ from extra.datasets import fetch_mnist Now we have everything we need to start training our neural network. We will be training for 1000 steps with a batch size of 64. +We use `with Tensor.train()` set the internal flag `Tensor.training` to `True` during training. +Upon exit, the flag is restored to its previous value by the context manager. + ```python X_train, Y_train, X_test, Y_test = fetch_mnist() -for step in range(1000): - # random sample a batch - samp = np.random.randint(0, X_train.shape[0], size=(64)) - batch = Tensor(X_train[samp], requires_grad=False) - # get the corresponding labels - labels = Tensor(Y_train[samp]) +with Tensor.train(): + for step in range(1000): + # random sample a batch + samp = np.random.randint(0, X_train.shape[0], size=(64)) + batch = Tensor(X_train[samp], requires_grad=False) + # get the corresponding labels + labels = Tensor(Y_train[samp]) - # forward pass - out = net(batch) + # forward pass + out = net(batch) - # compute loss - loss = sparse_categorical_crossentropy(out, labels) + # compute loss + loss = sparse_categorical_crossentropy(out, labels) - # zero gradients - opt.zero_grad() + # zero gradients + opt.zero_grad() - # backward pass - loss.backward() + # backward pass + loss.backward() - # update parameters - opt.step() + # update parameters + opt.step() - # calculate accuracy - pred = out.argmax(axis=-1) - acc = (pred == labels).mean() + # calculate accuracy + pred = out.argmax(axis=-1) + acc = (pred == labels).mean() - if step % 100 == 0: - print(f"Step {step+1} | Loss: {loss.numpy()} | Accuracy: {acc.numpy()}") + if step % 100 == 0: + print(f"Step {step+1} | Loss: {loss.numpy()} | Accuracy: {acc.numpy()}") ``` ## Evaluation @@ -215,9 +213,6 @@ Now that we have trained our neural network we can evaluate it on the test set. We will be using the same batch size of 64 and will be evaluating for 1000 of those batches. ```python -# set training flag to false -Tensor.training = False - with Timing("Time: "): avg_acc = 0 for step in range(1000): diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 0db852f9fa..2e3774800c 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -22,7 +22,7 @@ from extra.lr_scheduler import OneCycleLR from tinygrad.jit import TinyJit from extra.dist import collectives -BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 100) +BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000) class BatchNorm(nn.BatchNorm2d): def __init__(self, num_features): @@ -234,7 +234,6 @@ def train_cifar(): # this import needs to be done here because this is running in a subprocess from extra.dist import OOB - Tensor.training = True rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1) X_train, Y_train, X_test, Y_test = fetch_cifar() @@ -312,8 +311,7 @@ def train_cifar(): eval_step_jitted = TinyJit(eval_step) eval_step_ema_jitted = TinyJit(eval_step) - # 97 steps in 2 seconds = 20ms / step Tensor.training = True - + # 97 steps in 2 seconds = 20ms / step # step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136 # 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68 # 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1 @@ -327,77 +325,78 @@ def train_cifar(): best_eval = -1 i = 0 batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True) - while i <= STEPS: - if i%100 == 0 and i > 1: - # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True - corrects = [] - corrects_ema = [] - losses = [] - losses_ema = [] - for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False): - # further split batch if distributed - if getenv("DIST"): - Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)] + with Tensor.train(): + while i <= STEPS: + if i%100 == 0 and i > 1: + # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True + corrects = [] + corrects_ema = [] + losses = [] + losses_ema = [] + for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False): + # further split batch if distributed + if getenv("DIST"): + Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)] - correct, loss = eval_step_jitted(model, Xt, Yt) - losses.append(loss.numpy().tolist()) - corrects.extend(correct.numpy().tolist()) - if model_ema: - correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt) - losses_ema.append(loss_ema.numpy().tolist()) - corrects_ema.extend(correct_ema.numpy().tolist()) - - # collect accuracy across ranks - correct_sum, correct_len = sum(corrects), len(corrects) - if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema) - if getenv("DIST"): - if rank == 0: - for j in range(1, min(world_size, 5)): - if model_ema: - recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j) - else: - recv_sum, recv_len = OOB.recv(j) - correct_sum += recv_sum - correct_len += recv_len - if model_ema: - correct_sum_ema += recv_sum_ema - correct_len_ema += recv_len_ema - elif rank < min(world_size, 5): + correct, loss = eval_step_jitted(model, Xt, Yt) + losses.append(loss.numpy().tolist()) + corrects.extend(correct.numpy().tolist()) if model_ema: - OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0) - else: - OOB.send((correct_sum, correct_len), 0) + correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt) + losses_ema.append(loss_ema.numpy().tolist()) + corrects_ema.extend(correct_ema.numpy().tolist()) - # only rank 0 prints - if rank == 0: - acc = correct_sum/correct_len*100.0 - if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0 - if acc > best_eval: - best_eval = acc - print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i}") - if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}") + # collect accuracy across ranks + correct_sum, correct_len = sum(corrects), len(corrects) + if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema) + if getenv("DIST"): + if rank == 0: + for j in range(1, min(world_size, 5)): + if model_ema: + recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j) + else: + recv_sum, recv_len = OOB.recv(j) + correct_sum += recv_sum + correct_len += recv_len + if model_ema: + correct_sum_ema += recv_sum_ema + correct_len_ema += recv_len_ema + elif rank < min(world_size, 5): + if model_ema: + OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0) + else: + OOB.send((correct_sum, correct_len), 0) - if STEPS == 0 or i==STEPS: break - X, Y = next(batcher) - # further split batch if distributed - if getenv("DIST"): - X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank] - GlobalCounters.reset() - st = time.monotonic() - loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y) - et = time.monotonic() - loss_cpu = loss.numpy() - # EMA for network weights - if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0: - if model_ema is None: - model_ema = modelEMA(W, model) - model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']])) - cl = time.monotonic() - if not getenv("DIST"): - print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS") - else: - print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS") - i += 1 + # only rank 0 prints + if rank == 0: + acc = correct_sum/correct_len*100.0 + if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0 + if acc > best_eval: + best_eval = acc + print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i}") + if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}") + + if STEPS == 0 or i==STEPS: break + X, Y = next(batcher) + # further split batch if distributed + if getenv("DIST"): + X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank] + GlobalCounters.reset() + st = time.monotonic() + loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y) + et = time.monotonic() + loss_cpu = loss.numpy() + # EMA for network weights + if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0: + if model_ema is None: + model_ema = modelEMA(W, model) + model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']])) + cl = time.monotonic() + if not getenv("DIST"): + print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS") + else: + print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS") + i += 1 if __name__ == "__main__": if not getenv("DIST"): diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 8db72fc15c..4ad8742366 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -26,12 +26,11 @@ def train_maskrcnn(): pass if __name__ == "__main__": - Tensor.training = True - - for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","): - nm = f"train_{m}" - if nm in globals(): - print(f"training {m}") - globals()[nm]() + with Tensor.train(): + for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","): + nm = f"train_{m}" + if nm in globals(): + print(f"training {m}") + globals()[nm]() diff --git a/examples/simple_conv_bn.py b/examples/simple_conv_bn.py index 16182c26af..7d5add4da7 100644 --- a/examples/simple_conv_bn.py +++ b/examples/simple_conv_bn.py @@ -5,15 +5,15 @@ from tinygrad.nn import Conv2d, BatchNorm2d from tinygrad.nn.state import get_parameters if __name__ == "__main__": - Tensor.training = True + with Tensor.train(): - BS, C1, H, W = 4, 16, 224, 224 - C2, K, S, P = 64, 7, 2, 1 + BS, C1, H, W = 4, 16, 224, 224 + C2, K, S, P = 64, 7, 2, 1 - x = Tensor.uniform(BS, C1, H, W) - conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P) - bn = BatchNorm2d(C2, track_running_stats=False) - for t in get_parameters([x, conv, bn]): t.realize() + x = Tensor.uniform(BS, C1, H, W) + conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P) + bn = BatchNorm2d(C2, track_running_stats=False) + for t in get_parameters([x, conv, bn]): t.realize() - print("running network") - x.sequential([conv, bn]).numpy() + print("running network") + x.sequential([conv, bn]).numpy() diff --git a/examples/train_efficientnet.py b/examples/train_efficientnet.py index 98a4612cda..da1c76b79e 100644 --- a/examples/train_efficientnet.py +++ b/examples/train_efficientnet.py @@ -61,43 +61,43 @@ if __name__ == "__main__": else: X_train, Y_train = fetch_cifar() - Tensor.training = True - for i in (t := trange(steps)): - if IMAGENET: - X, Y = q.get(True) - else: - samp = np.random.randint(0, X_train.shape[0], size=(BS)) - X, Y = X_train[samp], Y_train[samp] + with Tensor.train() + for i in (t := trange(steps)): + if IMAGENET: + X, Y = q.get(True) + else: + samp = np.random.randint(0, X_train.shape[0], size=(BS)) + X, Y = X_train[samp], Y_train[samp] - st = time.time() - out = model.forward(Tensor(X.astype(np.float32), requires_grad=False)) - fp_time = (time.time()-st)*1000.0 + st = time.time() + out = model.forward(Tensor(X.astype(np.float32), requires_grad=False)) + fp_time = (time.time()-st)*1000.0 - y = np.zeros((BS,classes), np.float32) - y[range(y.shape[0]),Y] = -classes - y = Tensor(y, requires_grad=False) - loss = out.log_softmax().mul(y).mean() + y = np.zeros((BS,classes), np.float32) + y[range(y.shape[0]),Y] = -classes + y = Tensor(y, requires_grad=False) + loss = out.log_softmax().mul(y).mean() - optimizer.zero_grad() + optimizer.zero_grad() - st = time.time() - loss.backward() - bp_time = (time.time()-st)*1000.0 + st = time.time() + loss.backward() + bp_time = (time.time()-st)*1000.0 - st = time.time() - optimizer.step() - opt_time = (time.time()-st)*1000.0 + st = time.time() + optimizer.step() + opt_time = (time.time()-st)*1000.0 - st = time.time() - loss = loss.numpy() - cat = out.argmax(axis=1).numpy() - accuracy = (cat == Y).mean() - finish_time = (time.time()-st)*1000.0 + st = time.time() + loss = loss.numpy() + cat = out.argmax(axis=1).numpy() + accuracy = (cat == Y).mean() + finish_time = (time.time()-st)*1000.0 - # printing - t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" % - (loss, accuracy, - fp_time, bp_time, opt_time, finish_time, - fp_time + bp_time + opt_time + finish_time)) + # printing + t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" % + (loss, accuracy, + fp_time, bp_time, opt_time, finish_time, + fp_time + bp_time + opt_time + finish_time)) - del out, y, loss + del out, y, loss diff --git a/extra/training.py b/extra/training.py index b6c36a3918..132a10da81 100644 --- a/extra/training.py +++ b/extra/training.py @@ -5,31 +5,31 @@ from tinygrad.helpers import getenv def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y), transform=lambda x: x, target_transform=lambda x: x, noloss=False): - Tensor.training = True - losses, accuracies = [], [] - for i in (t := trange(steps, disable=getenv('CI', False))): - samp = np.random.randint(0, X_train.shape[0], size=(BS)) - x = Tensor(transform(X_train[samp]), requires_grad=False) - y = Tensor(target_transform(Y_train[samp])) + with Tensor.train(): + losses, accuracies = [], [] + for i in (t := trange(steps, disable=getenv('CI', False))): + samp = np.random.randint(0, X_train.shape[0], size=(BS)) + x = Tensor(transform(X_train[samp]), requires_grad=False) + y = Tensor(target_transform(Y_train[samp])) - # network - out = model.forward(x) if hasattr(model, 'forward') else model(x) + # network + out = model.forward(x) if hasattr(model, 'forward') else model(x) - loss = lossfn(out, y) - optim.zero_grad() - loss.backward() - if noloss: del loss - optim.step() + loss = lossfn(out, y) + optim.zero_grad() + loss.backward() + if noloss: del loss + optim.step() - # printing - if not noloss: - cat = out.argmax(axis=-1) - accuracy = (cat == y).mean().numpy() + # printing + if not noloss: + cat = out.argmax(axis=-1) + accuracy = (cat == y).mean().numpy() - loss = loss.detach().numpy() - losses.append(loss) - accuracies.append(accuracy) - t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy)) + loss = loss.detach().numpy() + losses.append(loss) + accuracies.append(accuracy) + t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy)) return [losses, accuracies] diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index c8c8e8db78..48dc84e490 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -171,73 +171,68 @@ class TestOpt(unittest.TestCase): assert ret.numpy()[0] == 33 def test_fold_batchnorm(self): - # TODO: with Tensor.training - Tensor.training = True - img = Tensor.ones(1,32,4,4) - bn = nn.BatchNorm2d(32, track_running_stats=False) - with CLCache(): - img_bn = bn(img).realize() - print(img_bn) - assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}" - Tensor.training = False + with Tensor.train(): + img = Tensor.ones(1,32,4,4) + bn = nn.BatchNorm2d(32, track_running_stats=False) + with CLCache(): + img_bn = bn(img).realize() + print(img_bn) + assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}" + # Tensor.training = False def test_fold_conv_sgd(self): - # TODO: with Tensor.training - Tensor.training = True - img = Tensor.ones(2,3,4,4) - c1 = nn.Conv2d(3,32,3) - opt = optim.SGD(get_parameters(c1)) - with CLCache(): - opt.zero_grad() - c1(img).relu().sum().backward() - opt.step() - # TODO: this should be 4, but the sum output child stays around - # with pushing_permutes it can be 3 - # TODO: broken with optim fixes - assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}" - Tensor.training = False + with Tensor.train(): + img = Tensor.ones(2,3,4,4) + c1 = nn.Conv2d(3,32,3) + opt = optim.SGD(get_parameters(c1)) + with CLCache(): + opt.zero_grad() + c1(img).relu().sum().backward() + opt.step() + # TODO: this should be 4, but the sum output child stays around + # with pushing_permutes it can be 3 + # TODO: broken with optim fixes + assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}" + # Tensor.training = False def test_fold_2convs_sgd(self): - # TODO: with Tensor.training - Tensor.training = True - img = Tensor.ones(2,3,64,64) - c1 = nn.Conv2d(3,16,3,bias=False) - c2 = nn.Conv2d(16,32,3,bias=False) - opt = optim.SGD(get_parameters([c1, c2])) - with CLCache(allowed=9): - opt.zero_grad() - c2(c1(img).relu()).relu().sum().backward() - opt.step() - Tensor.training = False + with Tensor.train(): + img = Tensor.ones(2,3,64,64) + c1 = nn.Conv2d(3,16,3,bias=False) + c2 = nn.Conv2d(16,32,3,bias=False) + opt = optim.SGD(get_parameters([c1, c2])) + with CLCache(allowed=9): + opt.zero_grad() + c2(c1(img).relu()).relu().sum().backward() + opt.step() + # Tensor.training = False def test_fold_4convs_sgd(self): - # TODO: with Tensor.training - Tensor.training = True - img = Tensor.ones(2,3,64,64) - c1 = nn.Conv2d(3,4,3,bias=False) - c2 = nn.Conv2d(4,8,3,bias=False) - c3 = nn.Conv2d(8,16,3,bias=False) - c4 = nn.Conv2d(16,32,3,bias=False) - opt = optim.SGD(get_parameters([c1, c2, c3, c4])) - with CLCache(allowed=19): - opt.zero_grad() - c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() - opt.step() - Tensor.training = False + with Tensor.train(): + img = Tensor.ones(2,3,64,64) + c1 = nn.Conv2d(3,4,3,bias=False) + c2 = nn.Conv2d(4,8,3,bias=False) + c3 = nn.Conv2d(8,16,3,bias=False) + c4 = nn.Conv2d(16,32,3,bias=False) + opt = optim.SGD(get_parameters([c1, c2, c3, c4])) + with CLCache(allowed=19): + opt.zero_grad() + c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() + opt.step() + # Tensor.training = False def test_fold_conv_batchnorm_sgd(self): - # TODO: with Tensor.training - Tensor.training = True - img = Tensor.ones(1,3,4,4) - c1 = nn.Conv2d(3,32,3) - bn = nn.BatchNorm2d(32, track_running_stats=False) - opt = optim.SGD(get_parameters([c1, bn])) - with CLCache(allowed=18): # this is too high - img_bn = bn(c1(img)).elu().sum() - opt.zero_grad() - img_bn.backward() - opt.step() - Tensor.training = False + with Tensor.train(): + img = Tensor.ones(1,3,4,4) + c1 = nn.Conv2d(3,32,3) + bn = nn.BatchNorm2d(32, track_running_stats=False) + opt = optim.SGD(get_parameters([c1, bn])) + with CLCache(allowed=18): # this is too high + img_bn = bn(c1(img)).elu().sum() + opt.zero_grad() + img_bn.backward() + opt.step() + # Tensor.training = False def test_fold_conv_batchnorm_notrain(self): img = Tensor.ones(1,3,8,8) @@ -250,15 +245,14 @@ class TestOpt(unittest.TestCase): assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}" def test_fold_conv_batchnorm(self): - Tensor.training = True - img = Tensor.ones(1,3,8,8) - c1 = nn.Conv2d(3,32,3) - bn = nn.BatchNorm2d(32, track_running_stats=False) - with CLCache(): - img_conv = bn(c1(img)).relu().realize() - print(img_conv) - assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}" - Tensor.training = False + with Tensor.train(): + img = Tensor.ones(1,3,8,8) + c1 = nn.Conv2d(3,32,3) + bn = nn.BatchNorm2d(32, track_running_stats=False) + with CLCache(): + img_conv = bn(c1(img)).relu().realize() + print(img_conv) + assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}" def test_fold_conv_elu(self): img = Tensor.ones(1,4,8,8) diff --git a/test/external/graph_batchnorm.py b/test/external/graph_batchnorm.py index f0813edf35..59e3b7961a 100644 --- a/test/external/graph_batchnorm.py +++ b/test/external/graph_batchnorm.py @@ -4,15 +4,14 @@ from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d, BatchNorm2d, optim def model_step(lm): - Tensor.training = True - x = Tensor.ones(8,12,128,256, requires_grad=False) - optimizer = optim.SGD(get_parameters(lm), lr=0.001) - loss = lm.forward(x).sum() - optimizer.zero_grad() - loss.backward() - del x,loss - optimizer.step() - Tensor.training = False + with Tensor.train(): + x = Tensor.ones(8,12,128,256, requires_grad=False) + optimizer = optim.SGD(get_parameters(lm), lr=0.001) + loss = lm.forward(x).sum() + optimizer.zero_grad() + loss.backward() + del x,loss + optimizer.step() class TestBatchnorm(unittest.TestCase): def test_conv(self): diff --git a/test/models/test_end2end.py b/test/models/test_end2end.py index b8e8f53a66..92f58085a3 100644 --- a/test/models/test_end2end.py +++ b/test/models/test_end2end.py @@ -9,50 +9,50 @@ from extra.datasets import fetch_mnist from tinygrad.helpers import CI def compare_tiny_torch(model, model_torch, X, Y): - Tensor.training = True - model_torch.train() - model_state_dict = get_state_dict(model) - for k,v in model_torch.named_parameters(): - if not CI: print(f"initting {k} from torch") - model_state_dict[k].assign(Tensor(v.detach().numpy())).realize() + with Tensor.train(): + model_torch.train() + model_state_dict = get_state_dict(model) + for k,v in model_torch.named_parameters(): + if not CI: print(f"initting {k} from torch") + model_state_dict[k].assign(Tensor(v.detach().numpy())).realize() - optimizer = optim.SGD(get_parameters(model), lr=0.01) - optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01) + optimizer = optim.SGD(get_parameters(model), lr=0.01) + optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01) - Xt = torch.Tensor(X.numpy()) - np.testing.assert_allclose(X.numpy(), Xt.detach().numpy()) + Xt = torch.Tensor(X.numpy()) + np.testing.assert_allclose(X.numpy(), Xt.detach().numpy()) - out = model(X) - loss = (out * Y).mean() - if not CI: print(loss.realize().numpy()) + out = model(X) + loss = (out * Y).mean() + if not CI: print(loss.realize().numpy()) - out_torch = model_torch(torch.Tensor(X.numpy())) - loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean() - if not CI: print(loss_torch.detach().numpy()) + out_torch = model_torch(torch.Tensor(X.numpy())) + loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean() + if not CI: print(loss_torch.detach().numpy()) - # assert losses match - np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4) + # assert losses match + np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4) - # zero and backward - optimizer.zero_grad() - loss.backward() - optimizer_torch.zero_grad() - loss_torch.backward() + # zero and backward + optimizer.zero_grad() + loss.backward() + optimizer_torch.zero_grad() + loss_torch.backward() - for k,v in list(model_torch.named_parameters())[::-1]: - g = model_state_dict[k].grad.numpy() - gt = v.grad.detach().numpy() - if not CI: print("testing grads", k) - np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}') + for k,v in list(model_torch.named_parameters())[::-1]: + g = model_state_dict[k].grad.numpy() + gt = v.grad.detach().numpy() + if not CI: print("testing grads", k) + np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}') - # take the steps - optimizer.step() - optimizer_torch.step() + # take the steps + optimizer.step() + optimizer_torch.step() - # assert weights match (they don't!) - for k,v in model_torch.named_parameters(): - if not CI: print("testing weight", k) - np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}') + # assert weights match (they don't!) + for k,v in model_torch.named_parameters(): + if not CI: print("testing weight", k) + np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}') def get_mnist_data(): X_train, Y_train, X_test, Y_test = fetch_mnist() diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index 1bf0516992..a58fcbb4ed 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -122,27 +122,24 @@ class TestRealWorld(unittest.TestCase): #Device.DEFAULT = "FAKE" #Device['fake'].codegen = Device[old_default].codegen - # TODO: with train - old_training = Tensor.training - Tensor.training = True - model = SpeedyResNet(Tensor.ones((12,3,2,2))) - optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15) + with Tensor.train(): + model = SpeedyResNet(Tensor.ones((12,3,2,2))) + optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15) - BS = 32 if CI else 512 + BS = 32 if CI else 512 - @TinyJit - def train(X): - out = model(X) - loss = out.mean() - optimizer.zero_grad() - loss.backward() - optimizer.step() + @TinyJit + def train(X): + out = model(X) + loss = out.mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() - helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 154) # it's 154 on metal + helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 154) # it's 154 on metal - # reset device - Tensor.training = old_training - #Device.DEFAULT = old_default + # reset device + #Device.DEFAULT = old_default if __name__ == '__main__': unittest.main() diff --git a/test/test_schedule.py b/test/test_schedule.py index 8d2f0cf013..ccbac47541 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -154,12 +154,11 @@ class TestSchedule(unittest.TestCase): #@unittest.skip("may want to reconsider this") def test_fold_batchnorm(self): - Tensor.training = True - img = Tensor.empty(1,32,4,4) - bn = nn.BatchNorm2d(32, track_running_stats=False) - out = bn(img) - check_schedule(out, 3) - Tensor.training = False + with Tensor.train(): + img = Tensor.empty(1,32,4,4) + bn = nn.BatchNorm2d(32, track_running_stats=False) + out = bn(img) + check_schedule(out, 3) def test_fold_conv_relu(self): c1 = nn.Conv2d(3,16,3) diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index 3adf9b7e8d..331fed289c 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -57,11 +57,11 @@ class TestSymbolicOps(unittest.TestCase): np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_attention_training(self): - Tensor.training = True - self.test_attention(dropout_p=0.0) - with self.assertRaises(AssertionError): - # symbolic shape dropout is not supported - self.test_attention(dropout_p=0.5) + with Tensor.train(): + self.test_attention(dropout_p=0.0) + with self.assertRaises(AssertionError): + # symbolic shape dropout is not supported + self.test_attention(dropout_p=0.5) def test_cat_dim0(self): def f(a, b): return a.cat(b, dim=0).realize() diff --git a/test/test_tensor.py b/test/test_tensor.py index d60ba5e514..199fe3a6f2 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -97,12 +97,12 @@ class TestTinygrad(unittest.TestCase): assert W.grad is not None def test_dropout(self): - Tensor.training = True - n, rate = 1_000_000, 0.1 - w = Tensor.ones(n).dropout(rate) - non_zeros = np.count_nonzero(w.numpy()) - expected = n * (1 - rate) - np.testing.assert_allclose(non_zeros, expected, rtol=2e-3) + with Tensor.train(): + n, rate = 1_000_000, 0.1 + w = Tensor.ones(n).dropout(rate) + non_zeros = np.count_nonzero(w.numpy()) + expected = n * (1 - rate) + np.testing.assert_allclose(non_zeros, expected, rtol=2e-3) def test_jacobian(self): W = np.random.RandomState(42069).random((10, 5)).astype(np.float32) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b7afef8a0e..469463840d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -5,7 +5,7 @@ from collections import defaultdict from functools import partialmethod, reduce from itertools import accumulate import numpy as np -from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence +from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int from tinygrad.lazy import LazyBuffer @@ -38,6 +38,12 @@ class Tensor: __slots__ = "lazydata", "requires_grad", "grad", "_ctx" __deletable__ = ('_ctx',) training: ClassVar[bool] = False + class train: + def __enter__(self): + self.prev = Tensor.training + Tensor.training = True + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev + no_grad: ClassVar[bool] = False default_type: ClassVar[DType] = dtypes.float32 def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):