diff --git a/examples/hlb_cifar10_torch.py b/examples/hlb_cifar10_torch.py
new file mode 100644
index 0000000000..7e5bcf1f9e
--- /dev/null
+++ b/examples/hlb_cifar10_torch.py
@@ -0,0 +1,104 @@
+import numpy as np
+import torch
+import time
+from torch import nn
+from torch import optim
+
+from datasets import fetch_cifar
+from tinygrad.helpers import getenv
+
+num_classes = 10
+class ConvGroup(nn.Module):
+  def __init__(self, channels_in, channels_out, short, se=True):
+    super().__init__()
+    self.short, self.se = short, se and not short
+    self.conv = nn.ModuleList([nn.Conv2d(channels_in if i == 0 else channels_out, channels_out, kernel_size=3, padding=1, bias=False) for i in range(1 if short else 3)])
+    self.norm = nn.ModuleList([nn.BatchNorm2d(channels_out, track_running_stats=False, eps=1e-12, momentum=0.8) for _ in range(1 if short else 3)])
+    if self.se: self.se1, self.se2 = nn.Linear(channels_out, channels_out//16), nn.Linear(channels_out//16, channels_out)
+
+  def forward(self, x):
+    x = nn.functional.max_pool2d(self.conv[0](x), 2)
+    x = self.norm[0](x).relu()
+    if self.short: return x
+    residual = x
+    mult = self.se2(self.se1(residual.mean((2,3))).relu()).sigmoid().reshape(x.shape[0], x.shape[1], 1, 1) if self.se else 1.0
+    x = self.norm[1](self.conv[1](x)).relu()
+    x = self.norm[2](self.conv[2](x) * mult).relu()
+    return x + residual
+
+class SpeedyResNet(nn.Module):
+  def __init__(self):
+    super().__init__()
+    # TODO: add whitening
+    self.ic = nn.Conv2d(3, 64, kernel_size=1)
+    self.ib = nn.BatchNorm2d(64, track_running_stats=False, eps=1e-12, momentum=0.8)
+    self.net = nn.ModuleList([
+      ConvGroup(64, 128, short=False),
+      ConvGroup(128, 256, short=True),
+      ConvGroup(256, 512, short=False),
+    ])
+    self.lin = nn.Linear(512, num_classes, bias=False)
+
+  # note: plain PyTorch would normally use torch.nn.CrossEntropyLoss (https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) instead of a log_softmax output
+  def forward(self, x):
+    x = self.ic(x)
+    x = self.ib(x)
+    x = x.relu()
+    for layer in self.net:
+      x = layer(x)
+    x = torch.amax(x, dim=(2,3))
+    x = self.lin(x)
+    return x.log_softmax(-1)
+
+def train_step_jitted(model, optimizer, X, Y):
+  out = model(X)
+  loss = (out * Y).mean()
+  optimizer.zero_grad()
+  loss.backward()
+  optimizer.step()
+  return loss
+
+def fetch_batch(X_train, Y_train, BS):
+  # fetch a random batch
+  samp = np.random.randint(0, X_train.shape[0], size=(BS))
+  Y = np.zeros((BS, num_classes), np.float32)
+  # -num_classes at the label index, so (out * Y).mean() in train_step_jitted is the mean NLL
+  Y[range(BS),Y_train[samp]] = -1.0*num_classes
+  X = torch.tensor(X_train[samp])
+  Y = torch.tensor(Y.reshape(BS, num_classes))
+  return X.cuda(), Y.cuda()
+
+def train_cifar():
+  BS = getenv("BS", 512)
+  if getenv("FAKEDATA"):
+    N = 2048
+    X_train = np.random.default_rng().standard_normal(size=(N, 3, 32, 32), dtype=np.float32)
+    Y_train = np.random.randint(0,10,size=(N), dtype=np.int32)
+    X_test, Y_test = X_train, Y_train
+  else:
+    X_train,Y_train = fetch_cifar(train=True)
+    X_test,Y_test = fetch_cifar(train=False)
+  print(X_train.shape, Y_train.shape)
+  Xt, Yt = fetch_batch(X_test, Y_test, BS=BS)
+
+  model = SpeedyResNet().cuda()
+  optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.85, nesterov=True)
+  X, Y = fetch_batch(X_train, Y_train, BS=BS)
+  for i in range(getenv("STEPS", 10)):
+    if i%10 == 0:
+      # use training batchnorm (and no_grad would change the kernels)
+      outs = model(Xt).detach().cpu().numpy().argmax(axis=1)
+      # argmin recovers the label: the correct class holds the only (negative) nonzero entry in Yt
+      correct = outs == Yt.detach().cpu().numpy().argmin(axis=1)
+      print(f"eval {sum(correct)}/{len(correct)} {sum(correct)/len(correct)*100.0:.2f}%")
+    st = time.monotonic()
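+    # the step below queues async CUDA work; et-st ("ms python") is host time, cl-et ("ms CL") includes the sync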
+    loss = train_step_jitted(model, optimizer, X, Y)
+    et = time.monotonic()
+    X, Y = fetch_batch(X_train, Y_train, BS=BS) # fetch the next batch now, before the .item() sync below blocks on the GPU
+    loss_cpu = loss.detach().cpu().item()
+    cl = time.monotonic()
+    print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss")
+
+if __name__ == "__main__":
+  train_cifar()
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index 6acad44b5b..fcf257b0f1 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -1,5 +1,6 @@
 from tinygrad.tensor import Tensor
 
+# TODO: BatchNorm2D -> BatchNorm2d
 class BatchNorm2D:
   def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
     assert affine, "BatchNorm2D is only supported with affine"
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 48e4cb93e2..ca4c88d72b 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -278,6 +278,7 @@ class Tensor:
     _, e, ss = self._softmax()
     return e.div(ss)
 
+  # TODO: logsoftmax -> log_softmax and add dim param
   def logsoftmax(self):
     m, _, ss = self._softmax()
     return m - ss.log()
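
As a sanity check on the loss construction above (not part of the diff): the target built in fetch_batch holds -num_classes at the label index and 0 elsewhere, so (out * Y).mean() over the log-probabilities reduces to ordinary mean cross-entropy. A minimal sketch, with made-up names (logits, labels):

  import torch
  import torch.nn.functional as F

  BS, num_classes = 4, 10
  logits = torch.randn(BS, num_classes)            # stand-in for the model output before log_softmax
  labels = torch.randint(0, num_classes, (BS,))

  # build the same target as fetch_batch: -num_classes at the label index, 0 elsewhere
  Y = torch.zeros(BS, num_classes)
  Y[torch.arange(BS), labels] = -1.0 * num_classes

  # mean over BS*num_classes entries of -num_classes * log p equals the mean NLL over the batch
  loss_example = (logits.log_softmax(-1) * Y).mean()
  loss_reference = F.cross_entropy(logits, labels)  # PyTorch's usual formulation
  assert torch.allclose(loss_example, loss_reference)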