import traceback
import time
from multiprocessing import Process, Queue

import numpy as np
from tqdm import trange

from tinygrad.nn import optim
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor
from datasets import fetch_cifar
from datasets.imagenet import fetch_batch
from models.efficientnet import EfficientNet


class TinyConvNet:
    """Small two-convolution classifier, selectable instead of EfficientNet."""

    def __init__(self, classes=10):
        """Create the two convolution kernels and the final linear weight.

        Args:
            classes: number of output classes (columns of the final matrix).
        """
        conv = 3
        inter_chan, out_chan = 8, 16  # for speed
        self.c1 = Tensor.uniform(inter_chan, 3, conv, conv)
        self.c2 = Tensor.uniform(out_chan, inter_chan, conv, conv)
        # NOTE(review): the 6*6 factor presumes the feature map is 6x6 after
        # the two conv + max-pool stages (e.g. 32x32 inputs) -- confirm.
        self.l1 = Tensor.uniform(out_chan * 6 * 6, classes)

    def forward(self, x):
        """Run conv -> relu -> max-pool twice, flatten, then apply `l1`.

        Args:
            x: input tensor; its first dimension is kept and every other
                dimension is flattened before the final matrix product.

        Returns:
            The result of `x.dot(self.l1)` on the flattened features.
        """
        x = x.conv2d(self.c1).relu().max_pool2d()
        x = x.conv2d(self.c2).relu().max_pool2d()
        x = x.reshape(shape=[x.shape[0], -1])
        return x.dot(self.l1)


def loader(q, batch_size):
    """Worker loop: keep putting `fetch_batch(batch_size)` results onto `q`.

    This lives at module level and receives the batch size explicitly
    (instead of being defined inside the ``__main__`` guard and reading a
    global) so that worker processes created with the "spawn" start method
    can find the function and do not depend on script-only globals.

    Args:
        q: queue that receives the fetched batches.
        batch_size: value forwarded to `fetch_batch`.

    Exceptions raised while fetching are printed and the loop carries on
    (best effort); the function never returns.
    """
    while True:
        try:
            q.put(fetch_batch(batch_size))
        except Exception:
            traceback.print_exc()


if __name__ == "__main__":
    # All configuration is read through `getenv`.
    IMAGENET = getenv("IMAGENET")
    classes = 1000 if IMAGENET else 10

    TINY = getenv("TINY")
    TRANSFER = getenv("TRANSFER")
    if TINY:
        model = TinyConvNet(classes)
    elif TRANSFER:
        model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
        model.load_from_pretrained()
    else:
        model = EfficientNet(getenv("NUM", 0), classes, has_se=False)

    parameters = optim.get_parameters(model)
    print("parameter count", len(parameters))
    optimizer = optim.Adam(parameters, lr=0.001)

    BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
    print(f"training with batch size {BS} for {steps} steps")

    if IMAGENET:
        # Two daemon workers fill a bounded queue with batches in the
        # background; daemon processes are terminated when the script exits.
        q = Queue(16)
        for _ in range(2):
            p = Process(target=loader, args=(q, BS))
            p.daemon = True
            p.start()
    else:
        X_train, Y_train = fetch_cifar()

    Tensor.training = True
    for _ in (t := trange(steps)):
        if IMAGENET:
            X, Y = q.get(True)
        else:
            # Sample a random batch (with replacement) from the training set.
            samp = np.random.randint(0, X_train.shape[0], size=BS)
            X, Y = X_train[samp], Y_train[samp]

        # Forward pass; every *_time below is in milliseconds.
        st = time.perf_counter()
        out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
        fp_time = (time.perf_counter() - st) * 1000.0

        # Target mask: -classes at each sample's label, 0 elsewhere. Since
        # .mean() divides by BS*classes, the loss equals the batch average of
        # -logsoftmax(out)[label].
        y = np.zeros((BS, classes), np.float32)
        y[range(y.shape[0]), Y] = -classes
        y = Tensor(y, requires_grad=False)
        loss = out.logsoftmax().mul(y).mean()

        optimizer.zero_grad()

        # Backward pass.
        st = time.perf_counter()
        loss.backward()
        bp_time = (time.perf_counter() - st) * 1000.0

        # Optimizer update.
        st = time.perf_counter()
        optimizer.step()
        opt_time = (time.perf_counter() - st) * 1000.0

        # Pull the results back to numpy and compute the batch accuracy.
        st = time.perf_counter()
        loss = loss.cpu().numpy()
        cat = np.argmax(out.cpu().numpy(), axis=1)
        accuracy = (cat == Y).mean()
        finish_time = (time.perf_counter() - st) * 1000.0

        # printing
        t.set_description(
            "loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f"
            % (
                loss,
                accuracy,
                fp_time,
                bp_time,
                opt_time,
                finish_time,
                fp_time + bp_time + opt_time + finish_time,
            )
        )

        # Drop the references to this step's tensors before the next step.
        del out, y, loss