Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 14:43:57 -05:00)
with Tensor.train() (#1935)
* add with.train
* remove the rest TODOs
* fix pyflake
* fix pyflake error
* fix mypy
@@ -131,12 +131,6 @@ Training neural networks in tinygrad is super simple.
All we need to do is define our neural network, define our loss function, and then call `.backward()` on the loss function to compute the gradients.
They can then be used to update the parameters of our neural network using one of the many optimizers in [optim.py](/tinygrad/nn/optim.py).

-First we need to set the training flag in `Tensor`:
-
-```python
-Tensor.training = True
-```
-
For our loss function we will be using sparse categorical cross entropy loss.

```python
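The hunk ends just as the quickstart's loss definition begins, so the helper itself is not visible here. As a point of reference only, here is a minimal sketch of a sparse categorical cross entropy built from `log_softmax`, assuming integer class labels; the name matches the `sparse_categorical_crossentropy` used in the loop below, but the body is illustrative rather than the repo's code:

```python
import numpy as np
from tinygrad.tensor import Tensor

def sparse_categorical_crossentropy(out: Tensor, Y: Tensor) -> Tensor:
  # one-hot encode the integer labels, then average the negative log-likelihood
  num_classes = out.shape[-1]
  y = Y.numpy().astype(np.int32)
  onehot = np.zeros((y.shape[0], num_classes), dtype=np.float32)
  onehot[np.arange(y.shape[0]), y] = 1.0
  return -(out.log_softmax() * Tensor(onehot)).sum(axis=-1).mean()
```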
@@ -176,37 +170,41 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64.

+We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
+Upon exit, the flag is restored to its previous value by the context manager.

```python
X_train, Y_train, X_test, Y_test = fetch_mnist()

+with Tensor.train():
  for step in range(1000):
    # random sample a batch
    samp = np.random.randint(0, X_train.shape[0], size=(64))
    batch = Tensor(X_train[samp], requires_grad=False)
    # get the corresponding labels
    labels = Tensor(Y_train[samp])

    # forward pass
    out = net(batch)

    # compute loss
    loss = sparse_categorical_crossentropy(out, labels)

    # zero gradients
    opt.zero_grad()

    # backward pass
    loss.backward()

    # update parameters
    opt.step()

    # calculate accuracy
    pred = out.argmax(axis=-1)
    acc = (pred == labels).mean()

    if step % 100 == 0:
      print(f"Step {step+1} | Loss: {loss.numpy()} | Accuracy: {acc.numpy()}")
```
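For intuition, the context manager is roughly equivalent to saving and restoring the flag by hand; a sketch of that equivalence (the real implementation is in the tensor.py hunk at the bottom of this diff):

```python
from tinygrad.tensor import Tensor

prev = Tensor.training   # save whatever the flag was
Tensor.training = True
try:
  pass  # ... training steps go here ...
finally:
  Tensor.training = prev # restored on exit, which is what `with Tensor.train()` does
```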

## Evaluation
@@ -215,9 +213,6 @@ Now that we have trained our neural network we can evaluate it on the test set.
We will be using the same batch size of 64 and will be evaluating for 1000 of those batches.

```python
-# set training flag to false
-Tensor.training = False
-
with Timing("Time: "):
  avg_acc = 0
  for step in range(1000):
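The evaluation body is cut off by the hunk above. A hypothetical completion, reusing the quickstart's `net`, `X_test`, `Y_test` and the batching pattern from the training loop (a sketch, not the repo's code); note that no `Tensor.train()` block is needed here since the flag defaults to `False`:

```python
with Timing("Time: "):
  avg_acc = 0
  for step in range(1000):
    # random sample a batch of test data
    samp = np.random.randint(0, X_test.shape[0], size=(64))
    batch = Tensor(X_test[samp], requires_grad=False)
    labels = Y_test[samp]

    # forward pass only, no gradients needed
    out = net(batch)

    # accumulate batch accuracy
    pred = out.argmax(axis=-1).numpy()
    avg_acc += (pred == labels).mean()
  print(f"Test Accuracy: {avg_acc / 1000}")
```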
@@ -22,7 +22,7 @@ from extra.lr_scheduler import OneCycleLR
from tinygrad.jit import TinyJit
from extra.dist import collectives

-BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 100)
+BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)

class BatchNorm(nn.BatchNorm2d):
  def __init__(self, num_features):
@@ -234,7 +234,6 @@ def train_cifar():

  # this import needs to be done here because this is running in a subprocess
  from extra.dist import OOB
-  Tensor.training = True
  rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)

  X_train, Y_train, X_test, Y_test = fetch_cifar()
@@ -312,8 +311,7 @@ def train_cifar():
  eval_step_jitted = TinyJit(eval_step)
  eval_step_ema_jitted = TinyJit(eval_step)

  # 97 steps in 2 seconds = 20ms / step
-  Tensor.training = True
  # step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
  # 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68
  # 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1
@@ -327,77 +325,78 @@ def train_cifar():
  best_eval = -1
  i = 0
  batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
+  with Tensor.train():
    while i <= STEPS:
      if i%100 == 0 and i > 1:
        # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
        corrects = []
        corrects_ema = []
        losses = []
        losses_ema = []
        for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
          # further split batch if distributed
          if getenv("DIST"):
            Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]

          correct, loss = eval_step_jitted(model, Xt, Yt)
          losses.append(loss.numpy().tolist())
          corrects.extend(correct.numpy().tolist())
          if model_ema:
            correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
            losses_ema.append(loss_ema.numpy().tolist())
            corrects_ema.extend(correct_ema.numpy().tolist())

        # collect accuracy across ranks
        correct_sum, correct_len = sum(corrects), len(corrects)
        if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
        if getenv("DIST"):
          if rank == 0:
            for j in range(1, min(world_size, 5)):
              if model_ema:
                recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
              else:
                recv_sum, recv_len = OOB.recv(j)
              correct_sum += recv_sum
              correct_len += recv_len
              if model_ema:
                correct_sum_ema += recv_sum_ema
                correct_len_ema += recv_len_ema
          elif rank < min(world_size, 5):
            if model_ema:
              OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
            else:
              OOB.send((correct_sum, correct_len), 0)

        # only rank 0 prints
        if rank == 0:
          acc = correct_sum/correct_len*100.0
          if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
          if acc > best_eval:
            best_eval = acc
            print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i}")
          if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")

      if STEPS == 0 or i==STEPS: break
      X, Y = next(batcher)
      # further split batch if distributed
      if getenv("DIST"):
        X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
      GlobalCounters.reset()
      st = time.monotonic()
      loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
      et = time.monotonic()
      loss_cpu = loss.numpy()
      # EMA for network weights
      if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
        if model_ema is None:
          model_ema = modelEMA(W, model)
        model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
      cl = time.monotonic()
      if not getenv("DIST"):
        print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
      else:
        print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
      i += 1

if __name__ == "__main__":
  if not getenv("DIST"):
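The loop above maintains an exponential moving average of the weights through `model_ema.update(...)`; the class itself is outside this hunk. A generic sketch of that kind of update, with illustrative names rather than the repo's `modelEMA`:

```python
from tinygrad.nn.state import get_parameters

def ema_update(ema_model, model, decay: float):
  # move each EMA weight a small step toward the live weight:
  # w_ema <- decay * w_ema + (1 - decay) * w
  for e, p in zip(get_parameters(ema_model), get_parameters(model)):
    e.assign(e.detach() * decay + p.detach() * (1 - decay)).realize()
```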
@@ -26,12 +26,11 @@ def train_maskrcnn():
  pass

if __name__ == "__main__":
-  Tensor.training = True

+  with Tensor.train():
    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
      nm = f"train_{m}"
      if nm in globals():
        print(f"training {m}")
        globals()[nm]()
@@ -5,15 +5,15 @@ from tinygrad.nn import Conv2d, BatchNorm2d
from tinygrad.nn.state import get_parameters

if __name__ == "__main__":
-  Tensor.training = True
+  with Tensor.train():

    BS, C1, H, W = 4, 16, 224, 224
    C2, K, S, P = 64, 7, 2, 1

    x = Tensor.uniform(BS, C1, H, W)
    conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
    bn = BatchNorm2d(C2, track_running_stats=False)
    for t in get_parameters([x, conv, bn]): t.realize()

    print("running network")
    x.sequential([conv, bn]).numpy()
@@ -61,43 +61,43 @@ if __name__ == "__main__":
  else:
    X_train, Y_train = fetch_cifar()

-  Tensor.training = True
+  with Tensor.train():
    for i in (t := trange(steps)):
      if IMAGENET:
        X, Y = q.get(True)
      else:
        samp = np.random.randint(0, X_train.shape[0], size=(BS))
        X, Y = X_train[samp], Y_train[samp]

      st = time.time()
      out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
      fp_time = (time.time()-st)*1000.0

      y = np.zeros((BS,classes), np.float32)
      y[range(y.shape[0]),Y] = -classes
      y = Tensor(y, requires_grad=False)
      loss = out.log_softmax().mul(y).mean()

      optimizer.zero_grad()

      st = time.time()
      loss.backward()
      bp_time = (time.time()-st)*1000.0

      st = time.time()
      optimizer.step()
      opt_time = (time.time()-st)*1000.0

      st = time.time()
      loss = loss.numpy()
      cat = out.argmax(axis=1).numpy()
      accuracy = (cat == Y).mean()
      finish_time = (time.time()-st)*1000.0

      # printing
      t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
        (loss, accuracy,
        fp_time, bp_time, opt_time, finish_time,
        fp_time + bp_time + opt_time + finish_time))

      del out, y, loss
@@ -5,31 +5,31 @@ from tinygrad.helpers import getenv

def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
          transform=lambda x: x, target_transform=lambda x: x, noloss=False):
-  Tensor.training = True
+  with Tensor.train():
    losses, accuracies = [], []
    for i in (t := trange(steps, disable=getenv('CI', False))):
      samp = np.random.randint(0, X_train.shape[0], size=(BS))
      x = Tensor(transform(X_train[samp]), requires_grad=False)
      y = Tensor(target_transform(Y_train[samp]))

      # network
      out = model.forward(x) if hasattr(model, 'forward') else model(x)

      loss = lossfn(out, y)
      optim.zero_grad()
      loss.backward()
      if noloss: del loss
      optim.step()

      # printing
      if not noloss:
        cat = out.argmax(axis=-1)
        accuracy = (cat == y).mean().numpy()

        loss = loss.detach().numpy()
        losses.append(loss)
        accuracies.append(accuracy)
        t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
  return [losses, accuracies]
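A hedged usage sketch of the `train()` helper defined above. The little `TinyNet` and the MNIST shapes are placeholders for illustration; the imports mirror names that appear elsewhere in this diff (`extra.datasets.fetch_mnist`, `tinygrad.nn.optim`, `get_parameters`), and the helper itself enters `Tensor.train()`, so the caller never touches the flag:

```python
from tinygrad.tensor import Tensor
from tinygrad.nn import Linear, optim
from tinygrad.nn.state import get_parameters
from extra.training import train          # assumed module path for the helper above
from extra.datasets import fetch_mnist

class TinyNet:
  def __init__(self): self.l1 = Linear(784, 10)
  def __call__(self, x: Tensor) -> Tensor: return self.l1(x.reshape(-1, 784))

X_train, Y_train, X_test, Y_test = fetch_mnist()
model = TinyNet()
opt = optim.SGD(get_parameters(model), lr=0.001)
# default lossfn is out.sparse_categorical_crossentropy(y), as in the signature above
losses, accuracies = train(model, X_train, Y_train, opt, steps=100, BS=64)
```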
test/external/external_test_opt.py
@@ -171,73 +171,68 @@ class TestOpt(unittest.TestCase):
    assert ret.numpy()[0] == 33

  def test_fold_batchnorm(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(1,32,4,4)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      with CLCache():
        img_bn = bn(img).realize()
        print(img_bn)
        assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
-    Tensor.training = False
+      # Tensor.training = False

  def test_fold_conv_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      opt = optim.SGD(get_parameters(c1))
      with CLCache():
        opt.zero_grad()
        c1(img).relu().sum().backward()
        opt.step()
        # TODO: this should be 4, but the sum output child stays around
        # with pushing_permutes it can be 3
        # TODO: broken with optim fixes
        assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
-    Tensor.training = False
+      # Tensor.training = False

  def test_fold_2convs_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(2,3,64,64)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = optim.SGD(get_parameters([c1, c2]))
      with CLCache(allowed=9):
        opt.zero_grad()
        c2(c1(img).relu()).relu().sum().backward()
        opt.step()
-    Tensor.training = False
+      # Tensor.training = False

  def test_fold_4convs_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(2,3,64,64)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      opt = optim.SGD(get_parameters([c1, c2, c3, c4]))
      with CLCache(allowed=19):
        opt.zero_grad()
        c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
        opt.step()
-    Tensor.training = False
+      # Tensor.training = False

  def test_fold_conv_batchnorm_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(1,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      opt = optim.SGD(get_parameters([c1, bn]))
      with CLCache(allowed=18): # this is too high
        img_bn = bn(c1(img)).elu().sum()
        opt.zero_grad()
        img_bn.backward()
        opt.step()
-    Tensor.training = False
+      # Tensor.training = False

  def test_fold_conv_batchnorm_notrain(self):
    img = Tensor.ones(1,3,8,8)
@@ -250,15 +245,14 @@ class TestOpt(unittest.TestCase):
      assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}"

  def test_fold_conv_batchnorm(self):
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.ones(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      with CLCache():
        img_conv = bn(c1(img)).relu().realize()
        print(img_conv)
        assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}"
-    Tensor.training = False

  def test_fold_conv_elu(self):
    img = Tensor.ones(1,4,8,8)
test/external/graph_batchnorm.py
@@ -4,15 +4,14 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d, optim

def model_step(lm):
-  Tensor.training = True
+  with Tensor.train():
    x = Tensor.ones(8,12,128,256, requires_grad=False)
    optimizer = optim.SGD(get_parameters(lm), lr=0.001)
    loss = lm.forward(x).sum()
    optimizer.zero_grad()
    loss.backward()
    del x,loss
    optimizer.step()
-  Tensor.training = False

class TestBatchnorm(unittest.TestCase):
  def test_conv(self):
@@ -9,50 +9,50 @@ from extra.datasets import fetch_mnist
from tinygrad.helpers import CI

def compare_tiny_torch(model, model_torch, X, Y):
-  Tensor.training = True
+  with Tensor.train():
    model_torch.train()
    model_state_dict = get_state_dict(model)
    for k,v in model_torch.named_parameters():
      if not CI: print(f"initting {k} from torch")
      model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()

    optimizer = optim.SGD(get_parameters(model), lr=0.01)
    optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01)

    Xt = torch.Tensor(X.numpy())
    np.testing.assert_allclose(X.numpy(), Xt.detach().numpy())

    out = model(X)
    loss = (out * Y).mean()
    if not CI: print(loss.realize().numpy())

    out_torch = model_torch(torch.Tensor(X.numpy()))
    loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
    if not CI: print(loss_torch.detach().numpy())

    # assert losses match
    np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)

    # zero and backward
    optimizer.zero_grad()
    loss.backward()
    optimizer_torch.zero_grad()
    loss_torch.backward()

    for k,v in list(model_torch.named_parameters())[::-1]:
      g = model_state_dict[k].grad.numpy()
      gt = v.grad.detach().numpy()
      if not CI: print("testing grads", k)
      np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}')

    # take the steps
    optimizer.step()
    optimizer_torch.step()

    # assert weights match (they don't!)
    for k,v in model_torch.named_parameters():
      if not CI: print("testing weight", k)
      np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}')

def get_mnist_data():
  X_train, Y_train, X_test, Y_test = fetch_mnist()
@@ -122,27 +122,24 @@ class TestRealWorld(unittest.TestCase):
    #Device.DEFAULT = "FAKE"
    #Device['fake'].codegen = Device[old_default].codegen

-    # TODO: with train
-    old_training = Tensor.training
-    Tensor.training = True
+    with Tensor.train():
      model = SpeedyResNet(Tensor.ones((12,3,2,2)))
      optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)

      BS = 32 if CI else 512

      @TinyJit
      def train(X):
        out = model(X)
        loss = out.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

      helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 154) # it's 154 on metal

    # reset device
-    Tensor.training = old_training
    #Device.DEFAULT = old_default

if __name__ == '__main__':
  unittest.main()
@@ -154,12 +154,11 @@ class TestSchedule(unittest.TestCase):

  #@unittest.skip("may want to reconsider this")
  def test_fold_batchnorm(self):
-    Tensor.training = True
+    with Tensor.train():
      img = Tensor.empty(1,32,4,4)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(img)
      check_schedule(out, 3)
-    Tensor.training = False

  def test_fold_conv_relu(self):
    c1 = nn.Conv2d(3,16,3)
@@ -57,11 +57,11 @@ class TestSymbolicOps(unittest.TestCase):
    np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

  def test_attention_training(self):
-    Tensor.training = True
+    with Tensor.train():
      self.test_attention(dropout_p=0.0)
      with self.assertRaises(AssertionError):
        # symbolic shape dropout is not supported
        self.test_attention(dropout_p=0.5)

  def test_cat_dim0(self):
    def f(a, b): return a.cat(b, dim=0).realize()
@@ -97,12 +97,12 @@ class TestTinygrad(unittest.TestCase):
    assert W.grad is not None

  def test_dropout(self):
-    Tensor.training = True
+    with Tensor.train():
      n, rate = 1_000_000, 0.1
      w = Tensor.ones(n).dropout(rate)
      non_zeros = np.count_nonzero(w.numpy())
      expected = n * (1 - rate)
      np.testing.assert_allclose(non_zeros, expected, rtol=2e-3)

  def test_jacobian(self):
    W = np.random.RandomState(42069).random((10, 5)).astype(np.float32)
@@ -5,7 +5,7 @@ from collections import defaultdict
from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any

from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
from tinygrad.lazy import LazyBuffer
@@ -38,6 +38,12 @@ class Tensor:
  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
  __deletable__ = ('_ctx',)
  training: ClassVar[bool] = False
+  class train:
+    def __enter__(self):
+      self.prev = Tensor.training
+      Tensor.training = True
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
+
  no_grad: ClassVar[bool] = False
  default_type: ClassVar[DType] = dtypes.float32
  def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
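A small usage example of the context manager added above; the asserts just spell out the save/restore behavior implemented by `__enter__` and `__exit__`:

```python
from tinygrad.tensor import Tensor

assert Tensor.training is False   # class-level default, as declared above
with Tensor.train():
  assert Tensor.training is True  # __enter__ saved the old value and set the flag
  # dropout, batchnorm, etc. now see training mode
assert Tensor.training is False   # __exit__ restored the previous value
```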