"""Train TinyConvNet or EfficientNet on CIFAR-10 or ImageNet with tinygrad.

Environment variables:
  IMAGENET  -- nonzero: train on ImageNet (1000 classes) via background loaders
  TINY      -- nonzero: use the small TinyConvNet instead of EfficientNet
  TRANSFER  -- nonzero: EfficientNet with SE blocks, initialized from pretrained weights
  NUM       -- EfficientNet variant number (default 0)
  BS        -- batch size (default 64 for TINY, else 16)
  STEPS     -- number of training steps (default 2048)
"""
import time
import traceback
from multiprocessing import Process, Queue

import numpy as np
from tqdm import trange

import tinygrad.nn.optim as optim
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor
from datasets import fetch_cifar
from datasets.imagenet import fetch_batch
from extra.utils import get_parameters
from models.efficientnet import EfficientNet


class TinyConvNet:
  """Tiny 2-conv CNN for quick CIFAR-scale experiments.

  Two conv+relu+maxpool stages followed by a single linear layer.
  The linear layer assumes a 6x6 spatial map after pooling, i.e. it
  expects 32x32 inputs (CIFAR) -- TODO confirm for other input sizes.
  """

  def __init__(self, classes=10):
    conv = 3  # 3x3 kernels
    inter_chan, out_chan = 8, 16  # small channel counts, for speed
    self.c1 = Tensor.uniform(inter_chan, 3, conv, conv)
    self.c2 = Tensor.uniform(out_chan, inter_chan, conv, conv)
    self.l1 = Tensor.uniform(out_chan*6*6, classes)

  def forward(self, x):
    """Return (batch, classes) logits for image batch x."""
    x = x.conv2d(self.c1).relu().max_pool2d()
    x = x.conv2d(self.c2).relu().max_pool2d()
    x = x.reshape(shape=[x.shape[0], -1])
    return x.dot(self.l1)


if __name__ == "__main__":
  IMAGENET = getenv("IMAGENET")
  classes = 1000 if IMAGENET else 10

  TINY = getenv("TINY")
  TRANSFER = getenv("TRANSFER")
  if TINY:
    model = TinyConvNet(classes)
  elif TRANSFER:
    model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
    model.load_from_pretrained()
  else:
    model = EfficientNet(getenv("NUM", 0), classes, has_se=False)

  parameters = get_parameters(model)
  print("parameter count", len(parameters))
  optimizer = optim.Adam(parameters, lr=0.001)

  BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
  print(f"training with batch size {BS} for {steps} steps")

  if IMAGENET:
    # Background workers keep a queue of prefetched batches so the training
    # loop does not stall on data loading.
    def loader(q):
      while True:  # was `while 1` -- same behavior, clearer intent
        try:
          q.put(fetch_batch(BS))
        except Exception:
          # Deliberate best-effort: log the failure and keep the worker
          # alive; one bad batch should not kill the loader process.
          traceback.print_exc()

    q = Queue(16)
    for _ in range(2):
      p = Process(target=loader, args=(q,))
      p.daemon = True  # workers die with the main process
      p.start()
  else:
    X_train, Y_train = fetch_cifar()

  Tensor.training = True
  for _ in (t := trange(steps)):  # loop index was unused; `t` drives the progress bar
    if IMAGENET:
      X, Y = q.get(block=True)  # blocks until a prefetched batch is available
    else:
      # fix: original `size=(BS)` was just an int in parentheses, not a tuple
      samp = np.random.randint(0, X_train.shape[0], size=BS)
      X, Y = X_train[samp], Y_train[samp]

    # forward pass (timed in ms)
    st = time.time()
    out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
    fp_time = (time.time()-st)*1000.0

    # NLL loss via logsoftmax: the target is one-hot scaled by -classes so
    # that mean() over the (BS, classes) product reduces to the negative
    # log-likelihood averaged over the batch.
    y = np.zeros((BS, classes), np.float32)
    y[range(y.shape[0]), Y] = -classes
    y = Tensor(y, requires_grad=False)
    loss = out.logsoftmax().mul(y).mean()

    optimizer.zero_grad()

    # backward pass (timed in ms)
    st = time.time()
    loss.backward()
    bp_time = (time.time()-st)*1000.0

    # optimizer step (timed in ms)
    st = time.time()
    optimizer.step()
    opt_time = (time.time()-st)*1000.0

    # evaluation / device sync: pulling .cpu().data forces computation
    st = time.time()
    loss = loss.cpu().data
    cat = np.argmax(out.cpu().data, axis=1)
    accuracy = (cat == Y).mean()
    finish_time = (time.time()-st)*1000.0

    # printing
    t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" % (
      loss, accuracy,
      fp_time, bp_time, opt_time, finish_time,
      fp_time + bp_time + opt_time + finish_time))

    # release the step's large objects before the next allocation
    del out, y, loss