import traceback
import time
from multiprocessing import Process, Queue

import numpy as np
from tqdm import trange

from tinygrad.nn import optim
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor

from datasets import fetch_cifar
from models.efficientnet import EfficientNet


class TinyConvNet:
    """Small convolutional classifier selected when the TINY env var is set.

    Two 3x3 convolutions, each followed by ReLU and max pooling, feed a
    single linear layer that produces ``classes`` outputs.
    """

    def __init__(self, classes=10):
        conv = 3
        inter_chan, out_chan = 8, 16  # for speed
        self.c1 = Tensor.uniform(inter_chan, 3, conv, conv)
        self.c2 = Tensor.uniform(out_chan, inter_chan, conv, conv)
        # The linear layer expects out_chan*6*6 flattened features per sample.
        # NOTE(review): presumably matches 32x32 inputs with 2x2 pooling -- confirm.
        self.l1 = Tensor.uniform(out_chan * 6 * 6, classes)

    def forward(self, x):
        """Return the class scores for the image batch ``x``."""
        x = x.conv2d(self.c1).relu().max_pool2d()
        x = x.conv2d(self.c2).relu().max_pool2d()
        # Flatten everything except the batch dimension.
        x = x.reshape(shape=[x.shape[0], -1])
        return x.dot(self.l1)


def _imagenet_loader(q, batch_size):
    """Push ImageNet batches of ``batch_size`` onto ``q`` forever.

    Runs as the target of a worker process. It lives at module level and
    receives everything it needs as arguments so that it can be located by
    child processes started with the ``spawn`` or ``forkserver`` methods,
    which re-import this module without executing the ``__main__`` section.
    """
    # Imported here so the name exists in the child regardless of how the
    # process was started.
    from datasets.imagenet import fetch_batch

    while True:
        try:
            q.put(fetch_batch(batch_size))
        except Exception:
            # Best effort: report the failure and keep trying to produce batches.
            traceback.print_exc()


if __name__ == "__main__":
    IMAGENET = getenv("IMAGENET")
    classes = 1000 if IMAGENET else 10

    TINY = getenv("TINY")
    TRANSFER = getenv("TRANSFER")
    if TINY:
        model = TinyConvNet(classes)
    elif TRANSFER:
        model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
        model.load_from_pretrained()
    else:
        model = EfficientNet(getenv("NUM", 0), classes, has_se=False)

    parameters = optim.get_parameters(model)
    print("parameter count", len(parameters))
    optimizer = optim.Adam(parameters, lr=0.001)

    BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
    print(f"training with batch size {BS} for {steps} steps")

    if IMAGENET:
        # Import in the parent as well, so a broken dataset module fails
        # right away instead of only inside the worker processes.
        from datasets.imagenet import fetch_batch  # noqa: F401

        q = Queue(16)
        for _ in range(2):
            p = Process(target=_imagenet_loader, args=(q, BS))
            p.daemon = True
            p.start()
    else:
        X_train, Y_train = fetch_cifar()

    Tensor.training = True
    for i in (t := trange(steps)):
        if IMAGENET:
            X, Y = q.get(True)
        else:
            samp = np.random.randint(0, X_train.shape[0], size=(BS))
            X, Y = X_train[samp], Y_train[samp]

        # perf_counter is monotonic, so the per-phase timings below cannot
        # be distorted by wall-clock adjustments.
        st = time.perf_counter()
        out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
        fp_time = (time.perf_counter() - st) * 1000.0

        # Target mask: -classes at each sample's label, 0 elsewhere. Taking
        # the mean over all BS*classes entries of log_softmax * mask then
        # equals the negative log-likelihood averaged over the batch.
        y = np.zeros((BS, classes), np.float32)
        y[range(y.shape[0]), Y] = -classes
        y = Tensor(y, requires_grad=False)
        loss = out.log_softmax().mul(y).mean()

        optimizer.zero_grad()

        st = time.perf_counter()
        loss.backward()
        bp_time = (time.perf_counter() - st) * 1000.0

        st = time.perf_counter()
        optimizer.step()
        opt_time = (time.perf_counter() - st) * 1000.0

        st = time.perf_counter()
        loss = loss.cpu().numpy()
        cat = np.argmax(out.cpu().numpy(), axis=1)
        accuracy = (cat == Y).mean()
        finish_time = (time.perf_counter() - st) * 1000.0

        # printing
        t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
                          (loss, accuracy,
                           fp_time, bp_time, opt_time, finish_time,
                           fp_time + bp_time + opt_time + finish_time))

        # Drop references to this step's tensors before the next iteration.
        del out, y, loss