with Tensor.train() (#1935)

* add with.train

* remove the rest TODOs

* fix pyflake

* fix pyflake error

* fix mypy
This commit is contained in:
Yixiang Gao
2023-09-28 20:02:31 -05:00
committed by GitHub
parent 10f0dc0c85
commit 094d3d71be
14 changed files with 305 additions and 317 deletions

View File

@@ -131,12 +131,6 @@ Training neural networks in tinygrad is super simple.
All we need to do is define our neural network, define our loss function, and then call `.backward()` on the loss function to compute the gradients.
They can then be used to update the parameters of our neural network using one of the many optimizers in [optim.py](/tinygrad/nn/optim.py).
First we need to set the training flag in `Tensor`:
```python
Tensor.training = True
```
For our loss function we will be using sparse categorical cross entropy loss.
```python
@@ -176,37 +170,41 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64.
We use `with Tensor.train()` set the internal flag `Tensor.training` to `True` during training.
Upon exit, the flag is restored to its previous value by the context manager.
```python
X_train, Y_train, X_test, Y_test = fetch_mnist()
for step in range(1000):
# random sample a batch
samp = np.random.randint(0, X_train.shape[0], size=(64))
batch = Tensor(X_train[samp], requires_grad=False)
# get the corresponding labels
labels = Tensor(Y_train[samp])
with Tensor.train():
for step in range(1000):
# random sample a batch
samp = np.random.randint(0, X_train.shape[0], size=(64))
batch = Tensor(X_train[samp], requires_grad=False)
# get the corresponding labels
labels = Tensor(Y_train[samp])
# forward pass
out = net(batch)
# forward pass
out = net(batch)
# compute loss
loss = sparse_categorical_crossentropy(out, labels)
# compute loss
loss = sparse_categorical_crossentropy(out, labels)
# zero gradients
opt.zero_grad()
# zero gradients
opt.zero_grad()
# backward pass
loss.backward()
# backward pass
loss.backward()
# update parameters
opt.step()
# update parameters
opt.step()
# calculate accuracy
pred = out.argmax(axis=-1)
acc = (pred == labels).mean()
# calculate accuracy
pred = out.argmax(axis=-1)
acc = (pred == labels).mean()
if step % 100 == 0:
print(f"Step {step+1} | Loss: {loss.numpy()} | Accuracy: {acc.numpy()}")
if step % 100 == 0:
print(f"Step {step+1} | Loss: {loss.numpy()} | Accuracy: {acc.numpy()}")
```
## Evaluation
@@ -215,9 +213,6 @@ Now that we have trained our neural network we can evaluate it on the test set.
We will be using the same batch size of 64 and will be evaluating for 1000 of those batches.
```python
# set training flag to false
Tensor.training = False
with Timing("Time: "):
avg_acc = 0
for step in range(1000):

View File

@@ -22,7 +22,7 @@ from extra.lr_scheduler import OneCycleLR
from tinygrad.jit import TinyJit
from extra.dist import collectives
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 100)
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
class BatchNorm(nn.BatchNorm2d):
def __init__(self, num_features):
@@ -234,7 +234,6 @@ def train_cifar():
# this import needs to be done here because this is running in a subprocess
from extra.dist import OOB
Tensor.training = True
rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
X_train, Y_train, X_test, Y_test = fetch_cifar()
@@ -312,8 +311,7 @@ def train_cifar():
eval_step_jitted = TinyJit(eval_step)
eval_step_ema_jitted = TinyJit(eval_step)
# 97 steps in 2 seconds = 20ms / step Tensor.training = True
# 97 steps in 2 seconds = 20ms / step
# step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
# 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68
# 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1
@@ -327,77 +325,78 @@ def train_cifar():
best_eval = -1
i = 0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
while i <= STEPS:
if i%100 == 0 and i > 1:
# Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
corrects = []
corrects_ema = []
losses = []
losses_ema = []
for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
# further split batch if distributed
if getenv("DIST"):
Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
with Tensor.train():
while i <= STEPS:
if i%100 == 0 and i > 1:
# Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
corrects = []
corrects_ema = []
losses = []
losses_ema = []
for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
# further split batch if distributed
if getenv("DIST"):
Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
correct, loss = eval_step_jitted(model, Xt, Yt)
losses.append(loss.numpy().tolist())
corrects.extend(correct.numpy().tolist())
if model_ema:
correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
losses_ema.append(loss_ema.numpy().tolist())
corrects_ema.extend(correct_ema.numpy().tolist())
# collect accuracy across ranks
correct_sum, correct_len = sum(corrects), len(corrects)
if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
if getenv("DIST"):
if rank == 0:
for j in range(1, min(world_size, 5)):
if model_ema:
recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
else:
recv_sum, recv_len = OOB.recv(j)
correct_sum += recv_sum
correct_len += recv_len
if model_ema:
correct_sum_ema += recv_sum_ema
correct_len_ema += recv_len_ema
elif rank < min(world_size, 5):
correct, loss = eval_step_jitted(model, Xt, Yt)
losses.append(loss.numpy().tolist())
corrects.extend(correct.numpy().tolist())
if model_ema:
OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
else:
OOB.send((correct_sum, correct_len), 0)
correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
losses_ema.append(loss_ema.numpy().tolist())
corrects_ema.extend(correct_ema.numpy().tolist())
# only rank 0 prints
if rank == 0:
acc = correct_sum/correct_len*100.0
if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
if acc > best_eval:
best_eval = acc
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i}")
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
# collect accuracy across ranks
correct_sum, correct_len = sum(corrects), len(corrects)
if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
if getenv("DIST"):
if rank == 0:
for j in range(1, min(world_size, 5)):
if model_ema:
recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
else:
recv_sum, recv_len = OOB.recv(j)
correct_sum += recv_sum
correct_len += recv_len
if model_ema:
correct_sum_ema += recv_sum_ema
correct_len_ema += recv_len_ema
elif rank < min(world_size, 5):
if model_ema:
OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
else:
OOB.send((correct_sum, correct_len), 0)
if STEPS == 0 or i==STEPS: break
X, Y = next(batcher)
# further split batch if distributed
if getenv("DIST"):
X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
GlobalCounters.reset()
st = time.monotonic()
loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
et = time.monotonic()
loss_cpu = loss.numpy()
# EMA for network weights
if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
if model_ema is None:
model_ema = modelEMA(W, model)
model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
cl = time.monotonic()
if not getenv("DIST"):
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
else:
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
i += 1
# only rank 0 prints
if rank == 0:
acc = correct_sum/correct_len*100.0
if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
if acc > best_eval:
best_eval = acc
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i}")
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
if STEPS == 0 or i==STEPS: break
X, Y = next(batcher)
# further split batch if distributed
if getenv("DIST"):
X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
GlobalCounters.reset()
st = time.monotonic()
loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
et = time.monotonic()
loss_cpu = loss.numpy()
# EMA for network weights
if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
if model_ema is None:
model_ema = modelEMA(W, model)
model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
cl = time.monotonic()
if not getenv("DIST"):
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
else:
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
i += 1
if __name__ == "__main__":
if not getenv("DIST"):

View File

@@ -26,12 +26,11 @@ def train_maskrcnn():
pass
if __name__ == "__main__":
Tensor.training = True
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
nm = f"train_{m}"
if nm in globals():
print(f"training {m}")
globals()[nm]()
with Tensor.train():
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
nm = f"train_{m}"
if nm in globals():
print(f"training {m}")
globals()[nm]()

View File

@@ -5,15 +5,15 @@ from tinygrad.nn import Conv2d, BatchNorm2d
from tinygrad.nn.state import get_parameters
if __name__ == "__main__":
Tensor.training = True
with Tensor.train():
BS, C1, H, W = 4, 16, 224, 224
C2, K, S, P = 64, 7, 2, 1
BS, C1, H, W = 4, 16, 224, 224
C2, K, S, P = 64, 7, 2, 1
x = Tensor.uniform(BS, C1, H, W)
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
bn = BatchNorm2d(C2, track_running_stats=False)
for t in get_parameters([x, conv, bn]): t.realize()
x = Tensor.uniform(BS, C1, H, W)
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
bn = BatchNorm2d(C2, track_running_stats=False)
for t in get_parameters([x, conv, bn]): t.realize()
print("running network")
x.sequential([conv, bn]).numpy()
print("running network")
x.sequential([conv, bn]).numpy()

View File

@@ -61,43 +61,43 @@ if __name__ == "__main__":
else:
X_train, Y_train = fetch_cifar()
Tensor.training = True
for i in (t := trange(steps)):
if IMAGENET:
X, Y = q.get(True)
else:
samp = np.random.randint(0, X_train.shape[0], size=(BS))
X, Y = X_train[samp], Y_train[samp]
with Tensor.train()
for i in (t := trange(steps)):
if IMAGENET:
X, Y = q.get(True)
else:
samp = np.random.randint(0, X_train.shape[0], size=(BS))
X, Y = X_train[samp], Y_train[samp]
st = time.time()
out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
fp_time = (time.time()-st)*1000.0
st = time.time()
out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
fp_time = (time.time()-st)*1000.0
y = np.zeros((BS,classes), np.float32)
y[range(y.shape[0]),Y] = -classes
y = Tensor(y, requires_grad=False)
loss = out.log_softmax().mul(y).mean()
y = np.zeros((BS,classes), np.float32)
y[range(y.shape[0]),Y] = -classes
y = Tensor(y, requires_grad=False)
loss = out.log_softmax().mul(y).mean()
optimizer.zero_grad()
optimizer.zero_grad()
st = time.time()
loss.backward()
bp_time = (time.time()-st)*1000.0
st = time.time()
loss.backward()
bp_time = (time.time()-st)*1000.0
st = time.time()
optimizer.step()
opt_time = (time.time()-st)*1000.0
st = time.time()
optimizer.step()
opt_time = (time.time()-st)*1000.0
st = time.time()
loss = loss.numpy()
cat = out.argmax(axis=1).numpy()
accuracy = (cat == Y).mean()
finish_time = (time.time()-st)*1000.0
st = time.time()
loss = loss.numpy()
cat = out.argmax(axis=1).numpy()
accuracy = (cat == Y).mean()
finish_time = (time.time()-st)*1000.0
# printing
t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
(loss, accuracy,
fp_time, bp_time, opt_time, finish_time,
fp_time + bp_time + opt_time + finish_time))
# printing
t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
(loss, accuracy,
fp_time, bp_time, opt_time, finish_time,
fp_time + bp_time + opt_time + finish_time))
del out, y, loss
del out, y, loss

View File

@@ -5,31 +5,31 @@ from tinygrad.helpers import getenv
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
transform=lambda x: x, target_transform=lambda x: x, noloss=False):
Tensor.training = True
losses, accuracies = [], []
for i in (t := trange(steps, disable=getenv('CI', False))):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x = Tensor(transform(X_train[samp]), requires_grad=False)
y = Tensor(target_transform(Y_train[samp]))
with Tensor.train():
losses, accuracies = [], []
for i in (t := trange(steps, disable=getenv('CI', False))):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x = Tensor(transform(X_train[samp]), requires_grad=False)
y = Tensor(target_transform(Y_train[samp]))
# network
out = model.forward(x) if hasattr(model, 'forward') else model(x)
# network
out = model.forward(x) if hasattr(model, 'forward') else model(x)
loss = lossfn(out, y)
optim.zero_grad()
loss.backward()
if noloss: del loss
optim.step()
loss = lossfn(out, y)
optim.zero_grad()
loss.backward()
if noloss: del loss
optim.step()
# printing
if not noloss:
cat = out.argmax(axis=-1)
accuracy = (cat == y).mean().numpy()
# printing
if not noloss:
cat = out.argmax(axis=-1)
accuracy = (cat == y).mean().numpy()
loss = loss.detach().numpy()
losses.append(loss)
accuracies.append(accuracy)
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
loss = loss.detach().numpy()
losses.append(loss)
accuracies.append(accuracy)
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
return [losses, accuracies]

View File

@@ -171,73 +171,68 @@ class TestOpt(unittest.TestCase):
assert ret.numpy()[0] == 33
def test_fold_batchnorm(self):
# TODO: with Tensor.training
Tensor.training = True
img = Tensor.ones(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False)
with CLCache():
img_bn = bn(img).realize()
print(img_bn)
assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
Tensor.training = False
with Tensor.train():
img = Tensor.ones(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False)
with CLCache():
img_bn = bn(img).realize()
print(img_bn)
assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
# Tensor.training = False
def test_fold_conv_sgd(self):
# TODO: with Tensor.training
Tensor.training = True
img = Tensor.ones(2,3,4,4)
c1 = nn.Conv2d(3,32,3)
opt = optim.SGD(get_parameters(c1))
with CLCache():
opt.zero_grad()
c1(img).relu().sum().backward()
opt.step()
# TODO: this should be 4, but the sum output child stays around
# with pushing_permutes it can be 3
# TODO: broken with optim fixes
assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
Tensor.training = False
with Tensor.train():
img = Tensor.ones(2,3,4,4)
c1 = nn.Conv2d(3,32,3)
opt = optim.SGD(get_parameters(c1))
with CLCache():
opt.zero_grad()
c1(img).relu().sum().backward()
opt.step()
# TODO: this should be 4, but the sum output child stays around
# with pushing_permutes it can be 3
# TODO: broken with optim fixes
assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
# Tensor.training = False
def test_fold_2convs_sgd(self):
# TODO: with Tensor.training
Tensor.training = True
img = Tensor.ones(2,3,64,64)
c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,3,bias=False)
opt = optim.SGD(get_parameters([c1, c2]))
with CLCache(allowed=9):
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
opt.step()
Tensor.training = False
with Tensor.train():
img = Tensor.ones(2,3,64,64)
c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,3,bias=False)
opt = optim.SGD(get_parameters([c1, c2]))
with CLCache(allowed=9):
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
opt.step()
# Tensor.training = False
def test_fold_4convs_sgd(self):
# TODO: with Tensor.training
Tensor.training = True
img = Tensor.ones(2,3,64,64)
c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False)
c3 = nn.Conv2d(8,16,3,bias=False)
c4 = nn.Conv2d(16,32,3,bias=False)
opt = optim.SGD(get_parameters([c1, c2, c3, c4]))
with CLCache(allowed=19):
opt.zero_grad()
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
opt.step()
Tensor.training = False
with Tensor.train():
img = Tensor.ones(2,3,64,64)
c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False)
c3 = nn.Conv2d(8,16,3,bias=False)
c4 = nn.Conv2d(16,32,3,bias=False)
opt = optim.SGD(get_parameters([c1, c2, c3, c4]))
with CLCache(allowed=19):
opt.zero_grad()
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
opt.step()
# Tensor.training = False
def test_fold_conv_batchnorm_sgd(self):
# TODO: with Tensor.training
Tensor.training = True
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
opt = optim.SGD(get_parameters([c1, bn]))
with CLCache(allowed=18): # this is too high
img_bn = bn(c1(img)).elu().sum()
opt.zero_grad()
img_bn.backward()
opt.step()
Tensor.training = False
with Tensor.train():
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
opt = optim.SGD(get_parameters([c1, bn]))
with CLCache(allowed=18): # this is too high
img_bn = bn(c1(img)).elu().sum()
opt.zero_grad()
img_bn.backward()
opt.step()
# Tensor.training = False
def test_fold_conv_batchnorm_notrain(self):
img = Tensor.ones(1,3,8,8)
@@ -250,15 +245,14 @@ class TestOpt(unittest.TestCase):
assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}"
def test_fold_conv_batchnorm(self):
Tensor.training = True
img = Tensor.ones(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
with CLCache():
img_conv = bn(c1(img)).relu().realize()
print(img_conv)
assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}"
Tensor.training = False
with Tensor.train():
img = Tensor.ones(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
with CLCache():
img_conv = bn(c1(img)).relu().realize()
print(img_conv)
assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}"
def test_fold_conv_elu(self):
img = Tensor.ones(1,4,8,8)

View File

@@ -4,15 +4,14 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d, optim
def model_step(lm):
Tensor.training = True
x = Tensor.ones(8,12,128,256, requires_grad=False)
optimizer = optim.SGD(get_parameters(lm), lr=0.001)
loss = lm.forward(x).sum()
optimizer.zero_grad()
loss.backward()
del x,loss
optimizer.step()
Tensor.training = False
with Tensor.train():
x = Tensor.ones(8,12,128,256, requires_grad=False)
optimizer = optim.SGD(get_parameters(lm), lr=0.001)
loss = lm.forward(x).sum()
optimizer.zero_grad()
loss.backward()
del x,loss
optimizer.step()
class TestBatchnorm(unittest.TestCase):
def test_conv(self):

View File

@@ -9,50 +9,50 @@ from extra.datasets import fetch_mnist
from tinygrad.helpers import CI
def compare_tiny_torch(model, model_torch, X, Y):
Tensor.training = True
model_torch.train()
model_state_dict = get_state_dict(model)
for k,v in model_torch.named_parameters():
if not CI: print(f"initting {k} from torch")
model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()
with Tensor.train():
model_torch.train()
model_state_dict = get_state_dict(model)
for k,v in model_torch.named_parameters():
if not CI: print(f"initting {k} from torch")
model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()
optimizer = optim.SGD(get_parameters(model), lr=0.01)
optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01)
optimizer = optim.SGD(get_parameters(model), lr=0.01)
optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01)
Xt = torch.Tensor(X.numpy())
np.testing.assert_allclose(X.numpy(), Xt.detach().numpy())
Xt = torch.Tensor(X.numpy())
np.testing.assert_allclose(X.numpy(), Xt.detach().numpy())
out = model(X)
loss = (out * Y).mean()
if not CI: print(loss.realize().numpy())
out = model(X)
loss = (out * Y).mean()
if not CI: print(loss.realize().numpy())
out_torch = model_torch(torch.Tensor(X.numpy()))
loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
if not CI: print(loss_torch.detach().numpy())
out_torch = model_torch(torch.Tensor(X.numpy()))
loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
if not CI: print(loss_torch.detach().numpy())
# assert losses match
np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)
# assert losses match
np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)
# zero and backward
optimizer.zero_grad()
loss.backward()
optimizer_torch.zero_grad()
loss_torch.backward()
# zero and backward
optimizer.zero_grad()
loss.backward()
optimizer_torch.zero_grad()
loss_torch.backward()
for k,v in list(model_torch.named_parameters())[::-1]:
g = model_state_dict[k].grad.numpy()
gt = v.grad.detach().numpy()
if not CI: print("testing grads", k)
np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}')
for k,v in list(model_torch.named_parameters())[::-1]:
g = model_state_dict[k].grad.numpy()
gt = v.grad.detach().numpy()
if not CI: print("testing grads", k)
np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}')
# take the steps
optimizer.step()
optimizer_torch.step()
# take the steps
optimizer.step()
optimizer_torch.step()
# assert weights match (they don't!)
for k,v in model_torch.named_parameters():
if not CI: print("testing weight", k)
np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}')
# assert weights match (they don't!)
for k,v in model_torch.named_parameters():
if not CI: print("testing weight", k)
np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}')
def get_mnist_data():
X_train, Y_train, X_test, Y_test = fetch_mnist()

View File

@@ -122,27 +122,24 @@ class TestRealWorld(unittest.TestCase):
#Device.DEFAULT = "FAKE"
#Device['fake'].codegen = Device[old_default].codegen
# TODO: with train
old_training = Tensor.training
Tensor.training = True
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
with Tensor.train():
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
BS = 32 if CI else 512
BS = 32 if CI else 512
@TinyJit
def train(X):
out = model(X)
loss = out.mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
@TinyJit
def train(X):
out = model(X)
loss = out.mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 154) # it's 154 on metal
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 154) # it's 154 on metal
# reset device
Tensor.training = old_training
#Device.DEFAULT = old_default
# reset device
#Device.DEFAULT = old_default
if __name__ == '__main__':
unittest.main()

View File

@@ -154,12 +154,11 @@ class TestSchedule(unittest.TestCase):
#@unittest.skip("may want to reconsider this")
def test_fold_batchnorm(self):
Tensor.training = True
img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img)
check_schedule(out, 3)
Tensor.training = False
with Tensor.train():
img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img)
check_schedule(out, 3)
def test_fold_conv_relu(self):
c1 = nn.Conv2d(3,16,3)

View File

@@ -57,11 +57,11 @@ class TestSymbolicOps(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_attention_training(self):
Tensor.training = True
self.test_attention(dropout_p=0.0)
with self.assertRaises(AssertionError):
# symbolic shape dropout is not supported
self.test_attention(dropout_p=0.5)
with Tensor.train():
self.test_attention(dropout_p=0.0)
with self.assertRaises(AssertionError):
# symbolic shape dropout is not supported
self.test_attention(dropout_p=0.5)
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()

View File

@@ -97,12 +97,12 @@ class TestTinygrad(unittest.TestCase):
assert W.grad is not None
def test_dropout(self):
Tensor.training = True
n, rate = 1_000_000, 0.1
w = Tensor.ones(n).dropout(rate)
non_zeros = np.count_nonzero(w.numpy())
expected = n * (1 - rate)
np.testing.assert_allclose(non_zeros, expected, rtol=2e-3)
with Tensor.train():
n, rate = 1_000_000, 0.1
w = Tensor.ones(n).dropout(rate)
non_zeros = np.count_nonzero(w.numpy())
expected = n * (1 - rate)
np.testing.assert_allclose(non_zeros, expected, rtol=2e-3)
def test_jacobian(self):
W = np.random.RandomState(42069).random((10, 5)).astype(np.float32)

View File

@@ -5,7 +5,7 @@ from collections import defaultdict
from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
from tinygrad.lazy import LazyBuffer
@@ -38,6 +38,12 @@ class Tensor:
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
__deletable__ = ('_ctx',)
training: ClassVar[bool] = False
class train:
def __enter__(self):
self.prev = Tensor.training
Tensor.training = True
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
no_grad: ClassVar[bool] = False
default_type: ClassVar[DType] = dtypes.float32
def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):