with Tensor.train() (#1935)

* add with.train

* remove the rest TODOs

* fix pyflake

* fix pyflake error

* fix mypy
Yixiang Gao
2023-09-28 20:02:31 -05:00
committed by GitHub
parent 10f0dc0c85
commit 094d3d71be
14 changed files with 305 additions and 317 deletions
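The change swaps manual flips of the class-level `Tensor.training` flag for a `Tensor.train()` context manager that sets the flag on entry and restores the previous value on exit. A minimal illustrative sketch of the migration pattern (not copied from any single file in this diff):

```python
from tinygrad.tensor import Tensor

# before: flip the global flag by hand and remember to reset it afterwards
Tensor.training = True
# ... training steps ...
Tensor.training = False

# after: the context manager handles both the set and the restore
with Tensor.train():
  pass  # ... training steps ...
```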

@@ -131,12 +131,6 @@ Training neural networks in tinygrad is super simple.
All we need to do is define our neural network, define our loss function, and then call `.backward()` on the loss function to compute the gradients. All we need to do is define our neural network, define our loss function, and then call `.backward()` on the loss function to compute the gradients.
They can then be used to update the parameters of our neural network using one of the many optimizers in [optim.py](/tinygrad/nn/optim.py). They can then be used to update the parameters of our neural network using one of the many optimizers in [optim.py](/tinygrad/nn/optim.py).
First we need to set the training flag in `Tensor`:
```python
Tensor.training = True
```
For our loss function we will be using sparse categorical cross entropy loss. For our loss function we will be using sparse categorical cross entropy loss.
```python ```python
@@ -176,37 +170,41 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64.
+We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
+Upon exit, the flag is restored to its previous value by the context manager.
```python
X_train, Y_train, X_test, Y_test = fetch_mnist()
+with Tensor.train():
  for step in range(1000):
    # random sample a batch
    samp = np.random.randint(0, X_train.shape[0], size=(64))
    batch = Tensor(X_train[samp], requires_grad=False)
    # get the corresponding labels
    labels = Tensor(Y_train[samp])
    # forward pass
    out = net(batch)
    # compute loss
    loss = sparse_categorical_crossentropy(out, labels)
    # zero gradients
    opt.zero_grad()
    # backward pass
    loss.backward()
    # update parameters
    opt.step()
    # calculate accuracy
    pred = out.argmax(axis=-1)
    acc = (pred == labels).mean()
    if step % 100 == 0:
      print(f"Step {step+1} | Loss: {loss.numpy()} | Accuracy: {acc.numpy()}")
```

## Evaluation
@@ -215,9 +213,6 @@ Now that we have trained our neural network we can evaluate it on the test set.
We will be using the same batch size of 64 and will be evaluating for 1000 of those batches.
```python
-# set training flag to false
-Tensor.training = False
with Timing("Time: "):
  avg_acc = 0
  for step in range(1000):
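A small sketch of why the explicit `Tensor.training = False` before evaluation is no longer needed (assuming the default `Tensor.training = False` and the `Tensor.train()` context manager added in this commit): the flag is only `True` inside the `with` block.

```python
from tinygrad.tensor import Tensor

with Tensor.train():
  assert Tensor.training is True   # training code sees the flag set
# code after the block sees the restored value, so evaluation needs no manual reset
assert Tensor.training is False
```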

@@ -22,7 +22,7 @@ from extra.lr_scheduler import OneCycleLR
from tinygrad.jit import TinyJit
from extra.dist import collectives
-BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 100)
+BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)

class BatchNorm(nn.BatchNorm2d):
  def __init__(self, num_features):
@@ -234,7 +234,6 @@ def train_cifar():
  # this import needs to be done here because this is running in a subprocess
  from extra.dist import OOB
-  Tensor.training = True
  rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
  X_train, Y_train, X_test, Y_test = fetch_cifar()
@@ -312,8 +311,7 @@ def train_cifar():
  eval_step_jitted = TinyJit(eval_step)
  eval_step_ema_jitted = TinyJit(eval_step)
  # 97 steps in 2 seconds = 20ms / step
-  Tensor.training = True
  # step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
  # 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68
  # 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1
@@ -327,77 +325,78 @@ def train_cifar():
  best_eval = -1
  i = 0
  batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
+  with Tensor.train():
    while i <= STEPS:
      if i%100 == 0 and i > 1:
        # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
        corrects = []
        corrects_ema = []
        losses = []
        losses_ema = []
        for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
          # further split batch if distributed
          if getenv("DIST"):
            Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
          correct, loss = eval_step_jitted(model, Xt, Yt)
          losses.append(loss.numpy().tolist())
          corrects.extend(correct.numpy().tolist())
          if model_ema:
            correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
            losses_ema.append(loss_ema.numpy().tolist())
            corrects_ema.extend(correct_ema.numpy().tolist())
        # collect accuracy across ranks
        correct_sum, correct_len = sum(corrects), len(corrects)
        if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
        if getenv("DIST"):
          if rank == 0:
            for j in range(1, min(world_size, 5)):
              if model_ema:
                recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
              else:
                recv_sum, recv_len = OOB.recv(j)
              correct_sum += recv_sum
              correct_len += recv_len
              if model_ema:
                correct_sum_ema += recv_sum_ema
                correct_len_ema += recv_len_ema
          elif rank < min(world_size, 5):
            if model_ema:
              OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
            else:
              OOB.send((correct_sum, correct_len), 0)
        # only rank 0 prints
        if rank == 0:
          acc = correct_sum/correct_len*100.0
          if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
          if acc > best_eval:
            best_eval = acc
            print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i}")
            if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
      if STEPS == 0 or i==STEPS: break
      X, Y = next(batcher)
      # further split batch if distributed
      if getenv("DIST"):
        X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
      GlobalCounters.reset()
      st = time.monotonic()
      loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
      et = time.monotonic()
      loss_cpu = loss.numpy()
      # EMA for network weights
      if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
        if model_ema is None:
          model_ema = modelEMA(W, model)
        model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
      cl = time.monotonic()
      if not getenv("DIST"):
        print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
      else:
        print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
      i += 1

if __name__ == "__main__":
  if not getenv("DIST"):

@@ -26,12 +26,11 @@ def train_maskrcnn():
  pass

if __name__ == "__main__":
-  Tensor.training = True
+  with Tensor.train():
    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
      nm = f"train_{m}"
      if nm in globals():
        print(f"training {m}")
        globals()[nm]()

@@ -5,15 +5,15 @@ from tinygrad.nn import Conv2d, BatchNorm2d
from tinygrad.nn.state import get_parameters

if __name__ == "__main__":
-  Tensor.training = True
+  with Tensor.train():
    BS, C1, H, W = 4, 16, 224, 224
    C2, K, S, P = 64, 7, 2, 1
    x = Tensor.uniform(BS, C1, H, W)
    conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
    bn = BatchNorm2d(C2, track_running_stats=False)
    for t in get_parameters([x, conv, bn]): t.realize()
    print("running network")
    x.sequential([conv, bn]).numpy()

@@ -61,43 +61,43 @@ if __name__ == "__main__":
  else:
    X_train, Y_train = fetch_cifar()
-  Tensor.training = True
+  with Tensor.train():
    for i in (t := trange(steps)):
      if IMAGENET:
        X, Y = q.get(True)
      else:
        samp = np.random.randint(0, X_train.shape[0], size=(BS))
        X, Y = X_train[samp], Y_train[samp]
      st = time.time()
      out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
      fp_time = (time.time()-st)*1000.0
      y = np.zeros((BS,classes), np.float32)
      y[range(y.shape[0]),Y] = -classes
      y = Tensor(y, requires_grad=False)
      loss = out.log_softmax().mul(y).mean()
      optimizer.zero_grad()
      st = time.time()
      loss.backward()
      bp_time = (time.time()-st)*1000.0
      st = time.time()
      optimizer.step()
      opt_time = (time.time()-st)*1000.0
      st = time.time()
      loss = loss.numpy()
      cat = out.argmax(axis=1).numpy()
      accuracy = (cat == Y).mean()
      finish_time = (time.time()-st)*1000.0
      # printing
      t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
        (loss, accuracy,
        fp_time, bp_time, opt_time, finish_time,
        fp_time + bp_time + opt_time + finish_time))
      del out, y, loss

@@ -5,31 +5,31 @@ from tinygrad.helpers import getenv
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
          transform=lambda x: x, target_transform=lambda x: x, noloss=False):
-  Tensor.training = True
+  with Tensor.train():
    losses, accuracies = [], []
    for i in (t := trange(steps, disable=getenv('CI', False))):
      samp = np.random.randint(0, X_train.shape[0], size=(BS))
      x = Tensor(transform(X_train[samp]), requires_grad=False)
      y = Tensor(target_transform(Y_train[samp]))
      # network
      out = model.forward(x) if hasattr(model, 'forward') else model(x)
      loss = lossfn(out, y)
      optim.zero_grad()
      loss.backward()
      if noloss: del loss
      optim.step()
      # printing
      if not noloss:
        cat = out.argmax(axis=-1)
        accuracy = (cat == y).mean().numpy()
        loss = loss.detach().numpy()
        losses.append(loss)
        accuracies.append(accuracy)
        t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
    return [losses, accuracies]

@@ -171,73 +171,68 @@ class TestOpt(unittest.TestCase):
    assert ret.numpy()[0] == 33

  def test_fold_batchnorm(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(1,32,4,4)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      with CLCache():
        img_bn = bn(img).realize()
        print(img_bn)
        assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
-    Tensor.training = False
+      # Tensor.training = False

  def test_fold_conv_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      opt = optim.SGD(get_parameters(c1))
      with CLCache():
        opt.zero_grad()
        c1(img).relu().sum().backward()
        opt.step()
        # TODO: this should be 4, but the sum output child stays around
        # with pushing_permutes it can be 3
        # TODO: broken with optim fixes
        assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
-    Tensor.training = False
+      # Tensor.training = False

  def test_fold_2convs_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(2,3,64,64)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = optim.SGD(get_parameters([c1, c2]))
      with CLCache(allowed=9):
        opt.zero_grad()
        c2(c1(img).relu()).relu().sum().backward()
        opt.step()
-    Tensor.training = False
+      # Tensor.training = False

  def test_fold_4convs_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(2,3,64,64)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      opt = optim.SGD(get_parameters([c1, c2, c3, c4]))
      with CLCache(allowed=19):
        opt.zero_grad()
        c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
        opt.step()
-    Tensor.training = False
+      # Tensor.training = False

  def test_fold_conv_batchnorm_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(1,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      opt = optim.SGD(get_parameters([c1, bn]))
      with CLCache(allowed=18): # this is too high
        img_bn = bn(c1(img)).elu().sum()
        opt.zero_grad()
        img_bn.backward()
        opt.step()
-    Tensor.training = False
+      # Tensor.training = False

  def test_fold_conv_batchnorm_notrain(self):
    img = Tensor.ones(1,3,8,8)
@@ -250,15 +245,14 @@ class TestOpt(unittest.TestCase):
      assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}"

  def test_fold_conv_batchnorm(self):
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      with CLCache():
        img_conv = bn(c1(img)).relu().realize()
        print(img_conv)
        assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}"
-    Tensor.training = False

  def test_fold_conv_elu(self):
    img = Tensor.ones(1,4,8,8)

@@ -4,15 +4,14 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d, optim

def model_step(lm):
-  Tensor.training = True
+  with Tensor.train():
    x = Tensor.ones(8,12,128,256, requires_grad=False)
    optimizer = optim.SGD(get_parameters(lm), lr=0.001)
    loss = lm.forward(x).sum()
    optimizer.zero_grad()
    loss.backward()
    del x,loss
    optimizer.step()
-  Tensor.training = False

class TestBatchnorm(unittest.TestCase):
  def test_conv(self):

@@ -9,50 +9,50 @@ from extra.datasets import fetch_mnist
from tinygrad.helpers import CI

def compare_tiny_torch(model, model_torch, X, Y):
-  Tensor.training = True
+  with Tensor.train():
    model_torch.train()
    model_state_dict = get_state_dict(model)
    for k,v in model_torch.named_parameters():
      if not CI: print(f"initting {k} from torch")
      model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()

    optimizer = optim.SGD(get_parameters(model), lr=0.01)
    optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01)

    Xt = torch.Tensor(X.numpy())
    np.testing.assert_allclose(X.numpy(), Xt.detach().numpy())

    out = model(X)
    loss = (out * Y).mean()
    if not CI: print(loss.realize().numpy())

    out_torch = model_torch(torch.Tensor(X.numpy()))
    loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
    if not CI: print(loss_torch.detach().numpy())

    # assert losses match
    np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)

    # zero and backward
    optimizer.zero_grad()
    loss.backward()
    optimizer_torch.zero_grad()
    loss_torch.backward()

    for k,v in list(model_torch.named_parameters())[::-1]:
      g = model_state_dict[k].grad.numpy()
      gt = v.grad.detach().numpy()
      if not CI: print("testing grads", k)
      np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}')

    # take the steps
    optimizer.step()
    optimizer_torch.step()

    # assert weights match (they don't!)
    for k,v in model_torch.named_parameters():
      if not CI: print("testing weight", k)
      np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}')

def get_mnist_data():
  X_train, Y_train, X_test, Y_test = fetch_mnist()

@@ -122,27 +122,24 @@ class TestRealWorld(unittest.TestCase):
    #Device.DEFAULT = "FAKE"
    #Device['fake'].codegen = Device[old_default].codegen
-    # TODO: with train
-    old_training = Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      model = SpeedyResNet(Tensor.ones((12,3,2,2)))
      optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
      BS = 32 if CI else 512
      @TinyJit
      def train(X):
        out = model(X)
        loss = out.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
      helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 154) # it's 154 on metal
      # reset device
-    Tensor.training = old_training
      #Device.DEFAULT = old_default

if __name__ == '__main__':
  unittest.main()

@@ -154,12 +154,11 @@ class TestSchedule(unittest.TestCase):
  #@unittest.skip("may want to reconsider this")
  def test_fold_batchnorm(self):
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.empty(1,32,4,4)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(img)
      check_schedule(out, 3)
-    Tensor.training = False

  def test_fold_conv_relu(self):
    c1 = nn.Conv2d(3,16,3)

@@ -57,11 +57,11 @@ class TestSymbolicOps(unittest.TestCase):
    np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

  def test_attention_training(self):
-    Tensor.training = True
+    with Tensor.train():
      self.test_attention(dropout_p=0.0)
      with self.assertRaises(AssertionError):
        # symbolic shape dropout is not supported
        self.test_attention(dropout_p=0.5)

  def test_cat_dim0(self):
    def f(a, b): return a.cat(b, dim=0).realize()

@@ -97,12 +97,12 @@ class TestTinygrad(unittest.TestCase):
    assert W.grad is not None

  def test_dropout(self):
-    Tensor.training = True
+    with Tensor.train():
      n, rate = 1_000_000, 0.1
      w = Tensor.ones(n).dropout(rate)
      non_zeros = np.count_nonzero(w.numpy())
      expected = n * (1 - rate)
      np.testing.assert_allclose(non_zeros, expected, rtol=2e-3)

  def test_jacobian(self):
    W = np.random.RandomState(42069).random((10, 5)).astype(np.float32)

@@ -5,7 +5,7 @@ from collections import defaultdict
from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
from tinygrad.lazy import LazyBuffer
@@ -38,6 +38,12 @@ class Tensor:
  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
  __deletable__ = ('_ctx',)
  training: ClassVar[bool] = False
+  class train:
+    def __enter__(self):
+      self.prev = Tensor.training
+      Tensor.training = True
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
+
  no_grad: ClassVar[bool] = False
  default_type: ClassVar[DType] = dtypes.float32
  def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
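For reference, a short sketch of how the context manager above behaves: `__exit__` always writes back the saved value, so the flag is restored even when the body raises, and nested blocks restore whatever they saw on entry (assuming the default `Tensor.training = False` shown above).

```python
from tinygrad.tensor import Tensor

# restored on exceptions: __exit__ still runs and puts back the saved value
try:
  with Tensor.train():
    assert Tensor.training is True
    raise RuntimeError("interrupted step")
except RuntimeError:
  pass
assert Tensor.training is False

# nesting: each level saves the value it saw on entry and restores it on exit
with Tensor.train():
  with Tensor.train():
    assert Tensor.training is True
  assert Tensor.training is True
assert Tensor.training is False
```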