mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 09:05:40 -05:00
* fast resnet eval * fix HIP multidevice graph * neater expression for devices * lines * add decorator test
55 lines
1.7 KiB
Python
55 lines
1.7 KiB
Python
# for imagenet download prepare.sh and run it
|
|
import glob, random, json
|
|
import numpy as np
|
|
from PIL import Image
|
|
import functools, pathlib
|
|
from tinygrad.helpers import DEBUG, diskcache
|
|
|
|
BASEDIR = pathlib.Path(__file__).parent / "imagenet"
|
|
|
|
@functools.lru_cache(None)
|
|
def get_imagenet_categories():
|
|
ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
|
|
return {v[0]: int(k) for k,v in ci.items()}
|
|
|
|
@diskcache
|
|
def get_train_files(): return glob.glob(str(BASEDIR / "train/*/*"))
|
|
|
|
@functools.lru_cache(None)
|
|
def get_val_files(): return glob.glob(str(BASEDIR / "val/*/*"))
|
|
|
|
def image_load(fn):
|
|
import torchvision.transforms.functional as F
|
|
img = Image.open(fn).convert('RGB')
|
|
img = F.resize(img, 256, Image.BILINEAR)
|
|
img = F.center_crop(img, 224)
|
|
ret = np.array(img)
|
|
return ret
|
|
|
|
def iterate(bs=32, val=True, shuffle=True):
|
|
cir = get_imagenet_categories()
|
|
files = get_val_files() if val else get_train_files()
|
|
order = list(range(0, len(files)))
|
|
if DEBUG >= 1: print(f"imagenet size {len(order)}")
|
|
if shuffle: random.shuffle(order)
|
|
from multiprocessing import Pool
|
|
p = Pool(16)
|
|
for i in range(0, len(files), bs):
|
|
X = p.map(image_load, [files[i] for i in order[i:i+bs]])
|
|
Y = [cir[files[i].split("/")[-2]] for i in order[i:i+bs]]
|
|
yield (np.array(X), np.array(Y))
|
|
|
|
def fetch_batch(bs, val=False):
|
|
cir = get_imagenet_categories()
|
|
files = get_val_files() if val else get_train_files()
|
|
samp = np.random.randint(0, len(files), size=(bs))
|
|
files = [files[i] for i in samp]
|
|
X = [image_load(x) for x in files]
|
|
Y = [cir[x.split("/")[0]] for x in files]
|
|
return np.array(X), np.array(Y)
|
|
|
|
if __name__ == "__main__":
|
|
X,Y = fetch_batch(64)
|
|
print(X.shape, Y)
|
|
|