Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00

with Tensor.train() (#1935)

* add with.train
* remove the rest TODOs
* fix pyflake
* fix pyflake error
* fix mypy
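In short, the commit replaces manual toggling of the `Tensor.training` class flag with a `Tensor.train()` context manager. A minimal before/after sketch of the pattern swapped in across the examples and tests below (the surrounding training code is elided and purely illustrative):

```python
from tinygrad.tensor import Tensor

# before: the flag had to be set and reset by hand around the training code
Tensor.training = True
# ... forward, backward, optimizer step ...
Tensor.training = False

# after: the context manager sets the flag on entry and restores it on exit
with Tensor.train():
  pass  # ... forward, backward, optimizer step ...
```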
@@ -131,12 +131,6 @@ Training neural networks in tinygrad is super simple.
 All we need to do is define our neural network, define our loss function, and then call `.backward()` on the loss function to compute the gradients.
 They can then be used to update the parameters of our neural network using one of the many optimizers in [optim.py](/tinygrad/nn/optim.py).
 
-First we need to set the training flag in `Tensor`:
-
-```python
-Tensor.training = True
-```
-
 For our loss function we will be using sparse categorical cross entropy loss.
 
 ```python
@@ -176,37 +170,41 @@ from extra.datasets import fetch_mnist
 Now we have everything we need to start training our neural network.
 We will be training for 1000 steps with a batch size of 64.
 
+We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
+Upon exit, the flag is restored to its previous value by the context manager.
+
 ```python
 X_train, Y_train, X_test, Y_test = fetch_mnist()
 
+with Tensor.train():
   for step in range(1000):
     # random sample a batch
     samp = np.random.randint(0, X_train.shape[0], size=(64))
     batch = Tensor(X_train[samp], requires_grad=False)
     # get the corresponding labels
     labels = Tensor(Y_train[samp])
 
     # forward pass
     out = net(batch)
 
     # compute loss
     loss = sparse_categorical_crossentropy(out, labels)
 
     # zero gradients
     opt.zero_grad()
 
     # backward pass
     loss.backward()
 
     # update parameters
     opt.step()
 
     # calculate accuracy
     pred = out.argmax(axis=-1)
     acc = (pred == labels).mean()
 
     if step % 100 == 0:
       print(f"Step {step+1} | Loss: {loss.numpy()} | Accuracy: {acc.numpy()}")
 ```
 
 ## Evaluation
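To make the context-manager behavior described above concrete, here is a short sketch that is not part of the commit; it assumes the usual tinygrad convention that `dropout` only drops values while `Tensor.training` is `True`:

```python
import numpy as np
from tinygrad.tensor import Tensor

with Tensor.train():
  # inside the block the training flag is set, so dropout is active (assumed gating)
  dropped = Tensor.ones(1000).dropout(0.5).numpy()
  assert np.count_nonzero(dropped) < 1000

# after the block the flag is back to its previous value (False by default)
assert Tensor.training is False
```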
@@ -215,9 +213,6 @@ Now that we have trained our neural network we can evaluate it on the test set.
 We will be using the same batch size of 64 and will be evaluating for 1000 of those batches.
 
 ```python
-# set training flag to false
-Tensor.training = False
-
 with Timing("Time: "):
   avg_acc = 0
   for step in range(1000):
@@ -22,7 +22,7 @@ from extra.lr_scheduler import OneCycleLR
 from tinygrad.jit import TinyJit
 from extra.dist import collectives
 
-BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 100)
+BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
 
 class BatchNorm(nn.BatchNorm2d):
   def __init__(self, num_features):
@@ -234,7 +234,6 @@ def train_cifar():
 
   # this import needs to be done here because this is running in a subprocess
   from extra.dist import OOB
-  Tensor.training = True
   rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
 
   X_train, Y_train, X_test, Y_test = fetch_cifar()
@@ -312,8 +311,7 @@ def train_cifar():
   eval_step_jitted = TinyJit(eval_step)
   eval_step_ema_jitted = TinyJit(eval_step)
 
   # 97 steps in 2 seconds = 20ms / step
-  Tensor.training = True
   # step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
   # 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68
   # 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1
@@ -327,77 +325,78 @@ def train_cifar():
   best_eval = -1
   i = 0
   batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
+  with Tensor.train():
     while i <= STEPS:
       if i%100 == 0 and i > 1:
         # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
         corrects = []
         corrects_ema = []
         losses = []
         losses_ema = []
         for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
           # further split batch if distributed
           if getenv("DIST"):
             Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
 
           correct, loss = eval_step_jitted(model, Xt, Yt)
           losses.append(loss.numpy().tolist())
           corrects.extend(correct.numpy().tolist())
           if model_ema:
             correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
             losses_ema.append(loss_ema.numpy().tolist())
             corrects_ema.extend(correct_ema.numpy().tolist())
 
         # collect accuracy across ranks
         correct_sum, correct_len = sum(corrects), len(corrects)
         if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
         if getenv("DIST"):
           if rank == 0:
             for j in range(1, min(world_size, 5)):
               if model_ema:
                 recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
               else:
                 recv_sum, recv_len = OOB.recv(j)
               correct_sum += recv_sum
               correct_len += recv_len
               if model_ema:
                 correct_sum_ema += recv_sum_ema
                 correct_len_ema += recv_len_ema
           elif rank < min(world_size, 5):
             if model_ema:
               OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
             else:
               OOB.send((correct_sum, correct_len), 0)
 
         # only rank 0 prints
         if rank == 0:
           acc = correct_sum/correct_len*100.0
           if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
           if acc > best_eval:
             best_eval = acc
             print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i}")
             if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
 
       if STEPS == 0 or i==STEPS: break
       X, Y = next(batcher)
       # further split batch if distributed
       if getenv("DIST"):
         X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
       GlobalCounters.reset()
       st = time.monotonic()
       loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
       et = time.monotonic()
       loss_cpu = loss.numpy()
       # EMA for network weights
       if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
         if model_ema is None:
           model_ema = modelEMA(W, model)
         model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
       cl = time.monotonic()
       if not getenv("DIST"):
         print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
       else:
         print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
       i += 1
 
 if __name__ == "__main__":
   if not getenv("DIST"):
@@ -26,12 +26,11 @@ def train_maskrcnn():
   pass
 
 if __name__ == "__main__":
-  Tensor.training = True
-
+  with Tensor.train():
     for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
       nm = f"train_{m}"
       if nm in globals():
         print(f"training {m}")
         globals()[nm]()
 
@@ -5,15 +5,15 @@ from tinygrad.nn import Conv2d, BatchNorm2d
 from tinygrad.nn.state import get_parameters
 
 if __name__ == "__main__":
-  Tensor.training = True
+  with Tensor.train():
 
     BS, C1, H, W = 4, 16, 224, 224
     C2, K, S, P = 64, 7, 2, 1
 
     x = Tensor.uniform(BS, C1, H, W)
     conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
     bn = BatchNorm2d(C2, track_running_stats=False)
     for t in get_parameters([x, conv, bn]): t.realize()
 
     print("running network")
     x.sequential([conv, bn]).numpy()
@@ -61,43 +61,43 @@ if __name__ == "__main__":
   else:
     X_train, Y_train = fetch_cifar()
 
-  Tensor.training = True
+  with Tensor.train():
     for i in (t := trange(steps)):
       if IMAGENET:
         X, Y = q.get(True)
       else:
         samp = np.random.randint(0, X_train.shape[0], size=(BS))
         X, Y = X_train[samp], Y_train[samp]
 
       st = time.time()
       out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
       fp_time = (time.time()-st)*1000.0
 
       y = np.zeros((BS,classes), np.float32)
       y[range(y.shape[0]),Y] = -classes
       y = Tensor(y, requires_grad=False)
       loss = out.log_softmax().mul(y).mean()
 
       optimizer.zero_grad()
 
       st = time.time()
       loss.backward()
       bp_time = (time.time()-st)*1000.0
 
       st = time.time()
       optimizer.step()
       opt_time = (time.time()-st)*1000.0
 
       st = time.time()
       loss = loss.numpy()
       cat = out.argmax(axis=1).numpy()
       accuracy = (cat == Y).mean()
       finish_time = (time.time()-st)*1000.0
 
       # printing
       t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
         (loss, accuracy,
         fp_time, bp_time, opt_time, finish_time,
         fp_time + bp_time + opt_time + finish_time))
 
       del out, y, loss
@@ -5,31 +5,31 @@ from tinygrad.helpers import getenv
 
 def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
           transform=lambda x: x, target_transform=lambda x: x, noloss=False):
-  Tensor.training = True
+  with Tensor.train():
     losses, accuracies = [], []
     for i in (t := trange(steps, disable=getenv('CI', False))):
       samp = np.random.randint(0, X_train.shape[0], size=(BS))
      x = Tensor(transform(X_train[samp]), requires_grad=False)
      y = Tensor(target_transform(Y_train[samp]))
 
      # network
      out = model.forward(x) if hasattr(model, 'forward') else model(x)
 
      loss = lossfn(out, y)
      optim.zero_grad()
      loss.backward()
      if noloss: del loss
      optim.step()
 
      # printing
      if not noloss:
        cat = out.argmax(axis=-1)
        accuracy = (cat == y).mean().numpy()
 
        loss = loss.detach().numpy()
        losses.append(loss)
        accuracies.append(accuracy)
 
        t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
    return [losses, accuracies]
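For reference, a hedged usage sketch of the `train` helper whose signature appears in the hunk above. The `TinyNet` model, the weight initialization, and the `transform` normalization are illustrative assumptions, not part of the commit:

```python
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from extra.training import train
from extra.datasets import fetch_mnist

# a tiny illustrative MLP; not part of the commit
class TinyNet:
  def __init__(self):
    self.l1 = Tensor.randn(784, 128)
    self.l2 = Tensor.randn(128, 10)
  def forward(self, x):
    return x.dot(self.l1).relu().dot(self.l2)

X_train, Y_train, X_test, Y_test = fetch_mnist()
model = TinyNet()
opt = optim.SGD(get_parameters(model), lr=0.001)
# train() now enters `with Tensor.train():` itself, so callers no longer touch Tensor.training
losses, accuracies = train(model, X_train, Y_train, opt, steps=100, BS=64,
                           transform=lambda x: x.reshape(-1, 784).astype(np.float32)/255.0)
```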
test/external/external_test_opt.py
@@ -171,73 +171,68 @@ class TestOpt(unittest.TestCase):
     assert ret.numpy()[0] == 33
 
   def test_fold_batchnorm(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
       img = Tensor.ones(1,32,4,4)
       bn = nn.BatchNorm2d(32, track_running_stats=False)
       with CLCache():
         img_bn = bn(img).realize()
         print(img_bn)
         assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
-    Tensor.training = False
+      # Tensor.training = False
 
   def test_fold_conv_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
       img = Tensor.ones(2,3,4,4)
       c1 = nn.Conv2d(3,32,3)
       opt = optim.SGD(get_parameters(c1))
       with CLCache():
         opt.zero_grad()
         c1(img).relu().sum().backward()
         opt.step()
         # TODO: this should be 4, but the sum output child stays around
         # with pushing_permutes it can be 3
         # TODO: broken with optim fixes
         assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
-    Tensor.training = False
+      # Tensor.training = False
 
   def test_fold_2convs_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
       img = Tensor.ones(2,3,64,64)
       c1 = nn.Conv2d(3,16,3,bias=False)
       c2 = nn.Conv2d(16,32,3,bias=False)
       opt = optim.SGD(get_parameters([c1, c2]))
       with CLCache(allowed=9):
         opt.zero_grad()
         c2(c1(img).relu()).relu().sum().backward()
         opt.step()
-    Tensor.training = False
+      # Tensor.training = False
 
   def test_fold_4convs_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
       img = Tensor.ones(2,3,64,64)
       c1 = nn.Conv2d(3,4,3,bias=False)
       c2 = nn.Conv2d(4,8,3,bias=False)
       c3 = nn.Conv2d(8,16,3,bias=False)
       c4 = nn.Conv2d(16,32,3,bias=False)
       opt = optim.SGD(get_parameters([c1, c2, c3, c4]))
       with CLCache(allowed=19):
         opt.zero_grad()
         c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
         opt.step()
-    Tensor.training = False
+      # Tensor.training = False
 
   def test_fold_conv_batchnorm_sgd(self):
-    # TODO: with Tensor.training
-    Tensor.training = True
+    with Tensor.train():
       img = Tensor.ones(1,3,4,4)
       c1 = nn.Conv2d(3,32,3)
       bn = nn.BatchNorm2d(32, track_running_stats=False)
       opt = optim.SGD(get_parameters([c1, bn]))
       with CLCache(allowed=18): # this is too high
         img_bn = bn(c1(img)).elu().sum()
         opt.zero_grad()
         img_bn.backward()
         opt.step()
-    Tensor.training = False
+      # Tensor.training = False
 
   def test_fold_conv_batchnorm_notrain(self):
     img = Tensor.ones(1,3,8,8)

@@ -250,15 +245,14 @@ class TestOpt(unittest.TestCase):
     assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}"
 
   def test_fold_conv_batchnorm(self):
-    Tensor.training = True
+    with Tensor.train():
       img = Tensor.ones(1,3,8,8)
       c1 = nn.Conv2d(3,32,3)
       bn = nn.BatchNorm2d(32, track_running_stats=False)
       with CLCache():
         img_conv = bn(c1(img)).relu().realize()
         print(img_conv)
         assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}"
-    Tensor.training = False
 
   def test_fold_conv_elu(self):
     img = Tensor.ones(1,4,8,8)
test/external/graph_batchnorm.py
@@ -4,15 +4,14 @@ from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d, BatchNorm2d, optim
 
 def model_step(lm):
-  Tensor.training = True
+  with Tensor.train():
     x = Tensor.ones(8,12,128,256, requires_grad=False)
     optimizer = optim.SGD(get_parameters(lm), lr=0.001)
     loss = lm.forward(x).sum()
     optimizer.zero_grad()
     loss.backward()
     del x,loss
     optimizer.step()
-  Tensor.training = False
 
 class TestBatchnorm(unittest.TestCase):
   def test_conv(self):
@@ -9,50 +9,50 @@ from extra.datasets import fetch_mnist
 from tinygrad.helpers import CI
 
 def compare_tiny_torch(model, model_torch, X, Y):
-  Tensor.training = True
+  with Tensor.train():
     model_torch.train()
     model_state_dict = get_state_dict(model)
     for k,v in model_torch.named_parameters():
       if not CI: print(f"initting {k} from torch")
       model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()
 
     optimizer = optim.SGD(get_parameters(model), lr=0.01)
     optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01)
 
     Xt = torch.Tensor(X.numpy())
     np.testing.assert_allclose(X.numpy(), Xt.detach().numpy())
 
     out = model(X)
     loss = (out * Y).mean()
     if not CI: print(loss.realize().numpy())
 
     out_torch = model_torch(torch.Tensor(X.numpy()))
     loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
     if not CI: print(loss_torch.detach().numpy())
 
     # assert losses match
     np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)
 
     # zero and backward
     optimizer.zero_grad()
     loss.backward()
     optimizer_torch.zero_grad()
     loss_torch.backward()
 
     for k,v in list(model_torch.named_parameters())[::-1]:
       g = model_state_dict[k].grad.numpy()
       gt = v.grad.detach().numpy()
       if not CI: print("testing grads", k)
       np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}')
 
     # take the steps
     optimizer.step()
     optimizer_torch.step()
 
     # assert weights match (they don't!)
     for k,v in model_torch.named_parameters():
       if not CI: print("testing weight", k)
       np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}')
 
 def get_mnist_data():
   X_train, Y_train, X_test, Y_test = fetch_mnist()
@@ -122,27 +122,24 @@ class TestRealWorld(unittest.TestCase):
     #Device.DEFAULT = "FAKE"
     #Device['fake'].codegen = Device[old_default].codegen
 
-    # TODO: with train
-    old_training = Tensor.training
-    Tensor.training = True
+    with Tensor.train():
       model = SpeedyResNet(Tensor.ones((12,3,2,2)))
       optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
 
       BS = 32 if CI else 512
 
       @TinyJit
       def train(X):
        out = model(X)
        loss = out.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
      helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 154) # it's 154 on metal
 
      # reset device
-    Tensor.training = old_training
      #Device.DEFAULT = old_default
 
 if __name__ == '__main__':
   unittest.main()
@@ -154,12 +154,11 @@ class TestSchedule(unittest.TestCase):
 
   #@unittest.skip("may want to reconsider this")
   def test_fold_batchnorm(self):
-    Tensor.training = True
+    with Tensor.train():
       img = Tensor.empty(1,32,4,4)
       bn = nn.BatchNorm2d(32, track_running_stats=False)
       out = bn(img)
       check_schedule(out, 3)
-    Tensor.training = False
 
   def test_fold_conv_relu(self):
     c1 = nn.Conv2d(3,16,3)
@@ -57,11 +57,11 @@ class TestSymbolicOps(unittest.TestCase):
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
   def test_attention_training(self):
-    Tensor.training = True
+    with Tensor.train():
       self.test_attention(dropout_p=0.0)
       with self.assertRaises(AssertionError):
         # symbolic shape dropout is not supported
         self.test_attention(dropout_p=0.5)
 
   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
@@ -97,12 +97,12 @@ class TestTinygrad(unittest.TestCase):
     assert W.grad is not None
 
   def test_dropout(self):
-    Tensor.training = True
+    with Tensor.train():
       n, rate = 1_000_000, 0.1
       w = Tensor.ones(n).dropout(rate)
       non_zeros = np.count_nonzero(w.numpy())
       expected = n * (1 - rate)
       np.testing.assert_allclose(non_zeros, expected, rtol=2e-3)
 
   def test_jacobian(self):
     W = np.random.RandomState(42069).random((10, 5)).astype(np.float32)
@@ -5,7 +5,7 @@ from collections import defaultdict
 from functools import partialmethod, reduce
 from itertools import accumulate
 import numpy as np
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any
 
 from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
 from tinygrad.lazy import LazyBuffer
@@ -38,6 +38,12 @@ class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
   training: ClassVar[bool] = False
+  class train:
+    def __enter__(self):
+      self.prev = Tensor.training
+      Tensor.training = True
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
+
   no_grad: ClassVar[bool] = False
   default_type: ClassVar[DType] = dtypes.float32
   def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
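A small usage sketch of the context manager defined in this hunk (illustrative, not part of the commit): because `__enter__` saves the previous value and `__exit__` writes it back, nesting is safe and the flag is restored even if the block raises.

```python
from tinygrad.tensor import Tensor

assert Tensor.training is False
with Tensor.train():
  assert Tensor.training is True
  with Tensor.train():             # nested entry saves the current value (True)
    assert Tensor.training is True
  assert Tensor.training is True   # inner exit restored the saved True
assert Tensor.training is False    # outer exit restored the original False

try:
  with Tensor.train():
    raise RuntimeError("boom")     # __exit__ still runs, restoring the flag
except RuntimeError:
  pass
assert Tensor.training is False
```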