import unittest, time import numpy as np from tinygrad import Device from tinygrad.nn.state import get_parameters from tinygrad.nn import optim from tinygrad.helpers import getenv from test.helpers import slow from extra.training import train from extra.models.convnext import ConvNeXt from extra.models.efficientnet import EfficientNet from extra.models.transformer import Transformer from extra.models.vit import ViT from extra.models.resnet import ResNet18 BS = getenv("BS", 2) def train_one_step(model,X,Y): params = get_parameters(model) pcount = 0 for p in params: pcount += np.prod(p.shape) optimizer = optim.SGD(params, lr=0.001) print("stepping %r with %.1fM params bs %d" % (type(model), pcount/1e6, BS)) st = time.time() train(model, X, Y, optimizer, steps=1, BS=BS) et = time.time()-st print("done in %.2f ms" % (et*1000.)) def check_gc(): if Device.DEFAULT == "CL": from extra.introspection import print_objects assert print_objects() == 0 class TestTrain(unittest.TestCase): def test_convnext(self): model = ConvNeXt(depths=[1], dims=[16]) X = np.zeros((BS,3,224,224), dtype=np.float32) Y = np.zeros((BS), dtype=np.int32) train_one_step(model,X,Y) check_gc() @slow def test_efficientnet(self): model = EfficientNet(0) X = np.zeros((BS,3,224,224), dtype=np.float32) Y = np.zeros((BS), dtype=np.int32) train_one_step(model,X,Y) check_gc() @slow def test_vit(self): model = ViT() X = np.zeros((BS,3,224,224), dtype=np.float32) Y = np.zeros((BS,), dtype=np.int32) train_one_step(model,X,Y) check_gc() @slow def test_transformer(self): # this should be small GPT-2, but the param count is wrong # (real ff_dim is 768*4) model = Transformer(syms=10, maxlen=6, layers=12, embed_dim=768, num_heads=12, ff_dim=768//4) X = np.zeros((BS,6), dtype=np.float32) Y = np.zeros((BS,6), dtype=np.int32) train_one_step(model,X,Y) check_gc() @slow def test_resnet(self): X = np.zeros((BS, 3, 224, 224), dtype=np.float32) Y = np.zeros((BS), dtype=np.int32) for resnet_v in [ResNet18]: model = resnet_v() model.load_from_pretrained() train_one_step(model, X, Y) check_gc() def test_bert(self): # TODO: write this pass if __name__ == '__main__': unittest.main()