fast imagenet eval, gets 76.14% across the set

This commit is contained in:
George Hotz
2023-05-13 12:39:01 -07:00
parent c552f6f92b
commit e0b2035023
3 changed files with 38 additions and 6 deletions

View File

@@ -27,12 +27,10 @@ def image_load(fn):
img = Image.open(fn).convert('RGB')
img = F.resize(img, 256, Image.BILINEAR)
img = F.center_crop(img, 224)
img = F.to_tensor(img)
img = F.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
ret = np.array(img, dtype='float32')
ret = np.array(img)
return ret
def iterate(bs, val=False, shuffle=True):
def iterate(bs=32, val=True, shuffle=True):
files = get_val_files() if val else get_train_files()
order = list(range(0, len(files)))
if shuffle: random.shuffle(order)

View File

@@ -1,3 +1,4 @@
import time
import numpy as np
from tinygrad.tensor import Tensor
@@ -7,21 +8,39 @@ if __name__ == "__main__":
Tensor.no_grad = True
# Resnet50-v1.5
from tinygrad.jit import TinyJit
from models.resnet import ResNet50
mdl = ResNet50()
mdl.load_from_pretrained()
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
def input_fixup(x):
x = x.permute([0,3,1,2]) / 255.0
x -= input_mean
x /= input_std
return x
mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
# evaluation on the mlperf classes of the validation set from imagenet
from datasets.imagenet import iterate
from extra.helpers import cross_process
n,d = 0,0
for x,y in iterate(32, True, shuffle=True):
st = time.perf_counter()
for x,y in cross_process(iterate):
dat = Tensor(x.astype(np.float32))
outs = mdl(dat)
mt = time.perf_counter()
outs = mdlrun(dat)
t = outs.numpy().argmax(axis=1)
et = time.perf_counter()
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
print(t)
print(y)
n += (t==y).sum()
d += len(t)
print(f"****** {n}/{d} {n*100.0/d:.2f}%")
st = time.perf_counter()

View File

@@ -22,3 +22,18 @@ def enable_early_exec():
qin.put(x)
return qout.get()
return early_exec
def proc(itermaker, q):
for x in itermaker(): q.put(x)
q.close()
def cross_process(itermaker, maxsize=8):
# TODO: use cloudpickle for itermaker
import multiprocessing
q = multiprocessing.Queue(maxsize)
p = multiprocessing.Process(target=proc, args=(itermaker, q))
p.daemon = True
p.start()
# TODO: write tests and handle exit case
while 1: yield q.get()