trainer works with CIFAR
@@ -3,31 +3,64 @@ import time
 import numpy as np
 from extra.efficientnet import EfficientNet
 from tinygrad.tensor import Tensor
-from tinygrad.utils import get_parameters
+from tinygrad.utils import get_parameters, fetch
+from tinygrad.utils import layer_init_uniform
 from tqdm import trange
 import tinygrad.optim as optim
+import io
+import tarfile
+import pickle
+
+class TinyConvNet:
+  def __init__(self, classes=10):
+    conv = 3
+    inter_chan, out_chan = 8, 16   # for speed
+    self.c1 = Tensor(layer_init_uniform(inter_chan,3,conv,conv))
+    self.c2 = Tensor(layer_init_uniform(out_chan,inter_chan,conv,conv))
+    self.l1 = Tensor(layer_init_uniform(out_chan*6*6, classes))
+
+  def forward(self, x):
+    x = x.conv2d(self.c1).relu().max_pool2d()
+    x = x.conv2d(self.c2).relu().max_pool2d()
+    x = x.reshape(shape=[x.shape[0], -1])
+    return x.dot(self.l1).logsoftmax()
+
+def load_cifar():
+  tt = tarfile.open(fileobj=io.BytesIO(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')), mode='r:gz')
+  db = pickle.load(tt.extractfile('cifar-10-batches-py/data_batch_1'), encoding="bytes")
+  X = db[b'data'].reshape((-1, 3, 32, 32))
+  Y = np.array(db[b'labels'])
+  return X, Y
 
 if __name__ == "__main__":
+  X_train, Y_train = load_cifar()
+  classes = 10
+
   Tensor.default_gpu = os.getenv("GPU") is not None
-  model = EfficientNet(int(os.getenv("NUM", "0")))
+  TINY = os.getenv("TINY") is not None
+  if TINY:
+    model = TinyConvNet(classes)
+  else:
+    model = EfficientNet(int(os.getenv("NUM", "0")), classes)
   parameters = get_parameters(model)
-  print(len(parameters))
+  print("parameters", len(parameters))
   optimizer = optim.Adam(parameters, lr=0.001)
 
-  BS = 16
-  img = np.zeros((BS,3,224,224), dtype=np.float32)
+  #BS, steps = 16, 32
+  BS, steps = 64 if TINY else 16, 1024
 
-  for i in range(32):
-    print("running batch %d, %d tensors allocated" % (i, Tensor.allocated))
+  for i in (t := trange(steps)):
+    samp = np.random.randint(0, X_train.shape[0], size=(BS))
+    img = X_train[samp].astype(np.float32)
 
     st = time.time()
     out = model.forward(Tensor(img))
-    et = time.time()
-    print("forward %.2f s" % (et-st))
+    fp_time = (time.time()-st)*1000.0
 
-    Y = [0]*BS
-    y = np.zeros((BS,1000), np.float32)
-    y[range(y.shape[0]),Y] = -1000.0
+    Y = Y_train[samp]
+    y = np.zeros((BS,classes), np.float32)
+    y[range(y.shape[0]),Y] = -classes
     y = Tensor(y)
     loss = out.logsoftmax().mul(y).mean()
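A note on the loss construction above: y is zero everywhere except a -classes at each sample's label, so out.logsoftmax().mul(y).mean() is exactly the mean negative log-likelihood. mean() averages over all BS*classes entries, and the -classes factor cancels the extra 1/classes. (The double log-softmax on the TinyConvNet path, whose forward already ends in logsoftmax, is harmless: log-softmax outputs already exp-sum to 1, so applying it again is the identity.) A minimal NumPy sketch, with illustrative names that are not part of the commit, checking the equivalence:

# Standalone check: the masked-mean trick equals textbook mean NLL
# because the -classes factor cancels the 1/classes from averaging
# over the class axis.
import numpy as np

BS, classes = 4, 10
logits = np.random.randn(BS, classes).astype(np.float32)
Y = np.random.randint(0, classes, size=BS)

# log-softmax per row
logprobs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

y = np.zeros((BS, classes), np.float32)
y[range(BS), Y] = -classes
loss_trick = (logprobs * y).mean()         # what the trainer computes
loss_nll = -logprobs[range(BS), Y].mean()  # textbook mean NLL

assert np.isclose(loss_trick, loss_nll)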
@@ -35,13 +68,18 @@ if __name__ == "__main__":
 
     st = time.time()
     loss.backward()
-    et = time.time()
-    print("backward %.2f s" % (et-st))
+    bp_time = (time.time()-st)*1000.0
 
     st = time.time()
     optimizer.step()
-    et = time.time()
-    print("optimizer %.2f s" % (et-st))
+    opt_time = (time.time()-st)*1000.0
 
+    cat = np.argmax(out.cpu().data, axis=1)
+    accuracy = (cat == Y).mean()
+
+    # printing
+    t.set_description("loss %.2f accuracy %.2f -- %.2f %.2f %.2f -- %d" %
+      (loss.cpu().data, accuracy, fp_time, bp_time, opt_time, Tensor.allocated))
+
     del out, y, loss
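That completes the trainer changes; the remaining two hunks thread the same classes parameter through the EfficientNet model itself (extra/efficientnet.py, per the import above). The trainer is configured entirely through environment variables: GPU selects the GPU backend, TINY swaps in TinyConvNet, and NUM picks the EfficientNet variant. A hedged driver sketch, where the script filename is an assumption this diff does not show:

# Hypothetical invocations; "train_efficientnet.py" is an assumed filename.
import os, subprocess

for extra_env in ({"TINY": "1"},   # TinyConvNet: BS=64 over 1024 steps
                  {"NUM": "2"},    # EfficientNet-B2 with a CIFAR-sized head
                  {"GPU": "1"}):   # run tensors on the GPU backend
  subprocess.run(["python", "train_efficientnet.py"],
                 env={**os.environ, **extra_env}, check=True)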
@@ -116,7 +116,7 @@ class MBConvBlock:
     return x
 
 class EfficientNet:
-  def __init__(self, number=0):
+  def __init__(self, number=0, classes=1000):
     self.number = number
     global_params = [
       # width, depth
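The new kwarg defaults to 1000, so existing ImageNet-shaped callers keep working unchanged; only the CIFAR trainer opts in. Illustrative calls:

from extra.efficientnet import EfficientNet

model = EfficientNet(0)              # unchanged: 1000-way ImageNet head
model = EfficientNet(0, classes=10)  # equivalent to what the trainer passes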
@@ -171,8 +171,8 @@ class EfficientNet:
     out_channels = round_filters(1280)
     self._conv_head = Tensor.zeros(out_channels, in_channels, 1, 1)
     self._bn1 = BatchNorm2D(out_channels)
-    self._fc = Tensor.zeros(out_channels, 1000)
-    self._fc_bias = Tensor.zeros(1000)
+    self._fc = Tensor.zeros(out_channels, classes)
+    self._fc_bias = Tensor.zeros(classes)
 
   def forward(self, x):
     x = x.pad2d(padding=(0,1,0,1))
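With both hunks applied, the classifier head follows classes instead of the hard-coded 1000. A quick shape check, assuming Tensor exposes a numpy-style .shape, which the trainer itself relies on (x.shape[0] in TinyConvNet.forward):

# Illustrative check of the resized head; shapes are read from the diff:
# _fc is (out_channels, classes), _fc_bias is (classes,).
from extra.efficientnet import EfficientNet

net = EfficientNet(0, classes=10)
assert net._fc.shape == (net._conv_head.shape[0], 10)
assert net._fc_bias.shape == (10,)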