tinygrad/examples/train_efficientnet.py

import os
import time
import numpy as np
from models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
from extra.utils import get_parameters
from tqdm import trange
from tinygrad.nn import BatchNorm2D
import tinygrad.optim as optim
from datasets import fetch_cifar
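
# TinyConvNet: a small two-conv-layer CNN, selected with TINY=1 as a fast
# stand-in for EfficientNet when sanity-checking the training loop.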
class TinyConvNet:
  def __init__(self, classes=10):
    conv = 3
    inter_chan, out_chan = 8, 16  # for speed
    self.c1 = Tensor.uniform(inter_chan,3,conv,conv)
    self.bn1 = BatchNorm2D(inter_chan)
    self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv)
    self.bn2 = BatchNorm2D(out_chan)
    self.l1 = Tensor.uniform(out_chan*6*6, classes)

  def forward(self, x):
    x = x.conv2d(self.c1).relu().max_pool2d()
    x = self.bn1(x)
    x = x.conv2d(self.c2).relu().max_pool2d()
    x = self.bn2(x)
    # flatten: 32x32 input -> 30x30 -> 15x15 -> 13x13 -> 6x6 after the two conv+pool stages
    x = x.reshape(shape=[x.shape[0], -1])
    return x.dot(self.l1).logsoftmax()
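
# Configuration is read from environment variables (see __main__ below):
#   TINY=1      train TinyConvNet instead of EfficientNet
#   TRANSFER=1  build EfficientNet with squeeze-and-excite blocks and load pretrained weights
#   NUM=n       EfficientNet variant number (default 0, i.e. B0)
#   BS=n        batch size (default 64 for TinyConvNet, 16 for EfficientNet)
# Example (hypothetical invocation): TINY=1 BS=128 python examples/train_efficientnet.py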
if __name__ == "__main__":
  X_train, Y_train = fetch_cifar()
  classes = 10

  TINY = os.getenv("TINY") is not None
  TRANSFER = os.getenv("TRANSFER") is not None
  if TINY:
    model = TinyConvNet(classes)
  elif TRANSFER:
    model = EfficientNet(int(os.getenv("NUM", "0")), classes, has_se=True)
    model.load_from_pretrained()
  else:
    model = EfficientNet(int(os.getenv("NUM", "0")), classes, has_se=False)

  parameters = get_parameters(model)
  print("parameter count", len(parameters))
  optimizer = optim.Adam(parameters, lr=0.001)

  BS, steps = int(os.getenv("BS", "64" if TINY else "16")), 2048
  print("training with batch size %d for %d steps" % (BS, steps))

  Tensor.training = True
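  # training loop: sample a random batch, run forward/backward, take an Adam step,
  # and report per-phase times (ms) plus loss and batch accuracy in the progress bar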
  for i in (t := trange(steps)):
    samp = np.random.randint(0, X_train.shape[0], size=(BS))
    img = X_train[samp].astype(np.float32)

    # forward pass
    st = time.time()
    out = model.forward(Tensor(img))
    fp_time = (time.time()-st)*1000.0

    # scaled one-hot targets: -classes at the label index, 0 elsewhere, so that
    # logsoftmax().mul(y).mean() over the BS*classes entries equals the NLL loss
    Y = Y_train[samp]
    y = np.zeros((BS,classes), np.float32)
    y[range(y.shape[0]),Y] = -classes
    y = Tensor(y)
    loss = out.logsoftmax().mul(y).mean()

    # backward pass
    optimizer.zero_grad()
    st = time.time()
    loss.backward()
    bp_time = (time.time()-st)*1000.0

    # optimizer step
    st = time.time()
    optimizer.step()
    opt_time = (time.time()-st)*1000.0

    # fetch loss and compute batch accuracy on the CPU
    #print(out.cpu().data)
    st = time.time()
    loss = loss.cpu().data
    cat = np.argmax(out.cpu().data, axis=1)
    accuracy = (cat == Y).mean()
    finish_time = (time.time()-st)*1000.0

    # printing
    t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
      (loss, accuracy,
       fp_time, bp_time, opt_time, finish_time,
       fp_time + bp_time + opt_time + finish_time))

    del out, y, loss