trainer works with CIFAR

This commit is contained in:
George Hotz
2020-12-06 12:20:14 -08:00
parent 80a9c777ba
commit 609d11e699
2 changed files with 58 additions and 20 deletions

View File

@@ -3,31 +3,64 @@ import time
import numpy as np
from extra.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
from tinygrad.utils import get_parameters
from tinygrad.utils import get_parameters, fetch
from tinygrad.utils import layer_init_uniform
from tqdm import trange
import tinygrad.optim as optim
import io
import tarfile
import pickle
class TinyConvNet:
def __init__(self, classes=10):
conv = 3
inter_chan, out_chan = 8, 16 # for speed
self.c1 = Tensor(layer_init_uniform(inter_chan,3,conv,conv))
self.c2 = Tensor(layer_init_uniform(out_chan,inter_chan,conv,conv))
self.l1 = Tensor(layer_init_uniform(out_chan*6*6, classes))
def forward(self, x):
x = x.conv2d(self.c1).relu().max_pool2d()
x = x.conv2d(self.c2).relu().max_pool2d()
x = x.reshape(shape=[x.shape[0], -1])
return x.dot(self.l1).logsoftmax()
def load_cifar():
tt = tarfile.open(fileobj=io.BytesIO(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')), mode='r:gz')
db = pickle.load(tt.extractfile('cifar-10-batches-py/data_batch_1'), encoding="bytes")
X = db[b'data'].reshape((-1, 3, 32, 32))
Y = np.array(db[b'labels'])
return X, Y
if __name__ == "__main__":
X_train, Y_train = load_cifar()
classes = 10
Tensor.default_gpu = os.getenv("GPU") is not None
model = EfficientNet(int(os.getenv("NUM", "0")))
TINY = os.getenv("TINY") is not None
if TINY:
model = TinyConvNet(classes)
else:
model = EfficientNet(int(os.getenv("NUM", "0")), classes)
parameters = get_parameters(model)
print(len(parameters))
print("parameters", len(parameters))
optimizer = optim.Adam(parameters, lr=0.001)
BS = 16
img = np.zeros((BS,3,224,224), dtype=np.float32)
#BS, steps = 16, 32
BS, steps = 64 if TINY else 16, 1024
for i in range(32):
print("running batch %d, %d tensors allocated" % (i, Tensor.allocated))
for i in (t := trange(steps)):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
img = X_train[samp].astype(np.float32)
st = time.time()
out = model.forward(Tensor(img))
et = time.time()
print("forward %.2f s" % (et-st))
fp_time = (time.time()-st)*1000.0
Y = [0]*BS
y = np.zeros((BS,1000), np.float32)
y[range(y.shape[0]),Y] = -1000.0
Y = Y_train[samp]
y = np.zeros((BS,classes), np.float32)
y[range(y.shape[0]),Y] = -classes
y = Tensor(y)
loss = out.logsoftmax().mul(y).mean()
@@ -35,13 +68,18 @@ if __name__ == "__main__":
st = time.time()
loss.backward()
et = time.time()
print("backward %.2f s" % (et-st))
bp_time = (time.time()-st)*1000.0
st = time.time()
optimizer.step()
et = time.time()
print("optimizer %.2f s" % (et-st))
opt_time = (time.time()-st)*1000.0
cat = np.argmax(out.cpu().data, axis=1)
accuracy = (cat == Y).mean()
# printing
t.set_description("loss %.2f accuracy %.2f -- %.2f %.2f %.2f -- %d" %
(loss.cpu().data, accuracy, fp_time, bp_time, opt_time, Tensor.allocated))
del out, y, loss

View File

@@ -116,7 +116,7 @@ class MBConvBlock:
return x
class EfficientNet:
def __init__(self, number=0):
def __init__(self, number=0, classes=1000):
self.number = number
global_params = [
# width, depth
@@ -171,8 +171,8 @@ class EfficientNet:
out_channels = round_filters(1280)
self._conv_head = Tensor.zeros(out_channels, in_channels, 1, 1)
self._bn1 = BatchNorm2D(out_channels)
self._fc = Tensor.zeros(out_channels, 1000)
self._fc_bias = Tensor.zeros(1000)
self._fc = Tensor.zeros(out_channels, classes)
self._fc_bias = Tensor.zeros(classes)
def forward(self, x):
x = x.pad2d(padding=(0,1,0,1))