mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fast imagenet eval, gets 76.14% across the set
This commit is contained in:
@@ -27,12 +27,10 @@ def image_load(fn):
|
||||
img = Image.open(fn).convert('RGB')
|
||||
img = F.resize(img, 256, Image.BILINEAR)
|
||||
img = F.center_crop(img, 224)
|
||||
img = F.to_tensor(img)
|
||||
img = F.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
|
||||
ret = np.array(img, dtype='float32')
|
||||
ret = np.array(img)
|
||||
return ret
|
||||
|
||||
def iterate(bs, val=False, shuffle=True):
|
||||
def iterate(bs=32, val=True, shuffle=True):
|
||||
files = get_val_files() if val else get_train_files()
|
||||
order = list(range(0, len(files)))
|
||||
if shuffle: random.shuffle(order)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
@@ -7,21 +8,39 @@ if __name__ == "__main__":
|
||||
Tensor.no_grad = True
|
||||
|
||||
# Resnet50-v1.5
|
||||
from tinygrad.jit import TinyJit
|
||||
from models.resnet import ResNet50
|
||||
mdl = ResNet50()
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
||||
def input_fixup(x):
|
||||
x = x.permute([0,3,1,2]) / 255.0
|
||||
x -= input_mean
|
||||
x /= input_std
|
||||
return x
|
||||
|
||||
mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
|
||||
|
||||
# evaluation on the mlperf classes of the validation set from imagenet
|
||||
from datasets.imagenet import iterate
|
||||
from extra.helpers import cross_process
|
||||
|
||||
n,d = 0,0
|
||||
for x,y in iterate(32, True, shuffle=True):
|
||||
st = time.perf_counter()
|
||||
for x,y in cross_process(iterate):
|
||||
dat = Tensor(x.astype(np.float32))
|
||||
outs = mdl(dat)
|
||||
mt = time.perf_counter()
|
||||
outs = mdlrun(dat)
|
||||
t = outs.numpy().argmax(axis=1)
|
||||
et = time.perf_counter()
|
||||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
|
||||
print(t)
|
||||
print(y)
|
||||
n += (t==y).sum()
|
||||
d += len(t)
|
||||
print(f"****** {n}/{d} {n*100.0/d:.2f}%")
|
||||
st = time.perf_counter()
|
||||
|
||||
|
||||
|
||||
@@ -22,3 +22,18 @@ def enable_early_exec():
|
||||
qin.put(x)
|
||||
return qout.get()
|
||||
return early_exec
|
||||
|
||||
def proc(itermaker, q):
|
||||
for x in itermaker(): q.put(x)
|
||||
q.close()
|
||||
|
||||
def cross_process(itermaker, maxsize=8):
|
||||
# TODO: use cloudpickle for itermaker
|
||||
import multiprocessing
|
||||
q = multiprocessing.Queue(maxsize)
|
||||
p = multiprocessing.Process(target=proc, args=(itermaker, q))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
# TODO: write tests and handle exit case
|
||||
while 1: yield q.get()
|
||||
|
||||
Reference in New Issue
Block a user