pad first batch of imagenet dataloader and update eval (#4368)

* pad first batch of imagenet dataloader and update eval

* pad zero instead of empty for training
This commit is contained in:
chenyu
2024-05-01 00:21:52 -04:00
committed by GitHub
parent 4a26718ca9
commit 683b7c605a
2 changed files with 46 additions and 25 deletions

View File

@@ -5,7 +5,7 @@ import numpy as np
from PIL import Image
from tqdm import tqdm
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Timing, Context
from tinygrad.helpers import getenv, prod, Context, round_up
from collections import deque
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count, Pool
@@ -44,21 +44,25 @@ def loader_process(q_in, q_out, X:Tensor, seed):
with Context(DEBUG=0):
while (_recv := q_in.get()) is not None:
idx, fn, val = _recv
img = Image.open(fn)
img = img.convert('RGB') if img.mode != "RGB" else img
if fn is not None:
img = Image.open(fn)
img = img.convert('RGB') if img.mode != "RGB" else img
if val:
# eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
# sudo apt-get install libjpeg-dev
# CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
img = center_crop(img)
img = np.array(img)
if val:
# eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
# sudo apt-get install libjpeg-dev
# CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
img = center_crop(img)
img = np.array(img)
else:
# reseed rng for determinism
if seed is not None:
np.random.seed(seed * 2 ** 20 + idx)
random.seed(seed * 2 ** 20 + idx)
img = preprocess_train(img)
else:
# reseed rng for determinism
if seed is not None:
np.random.seed(seed * 2 ** 20 + idx)
random.seed(seed * 2 ** 20 + idx)
img = preprocess_train(img)
# pad zeros
img = np.zeros((224, 224, 3), dtype=np.uint8)
# broken out
#img_tensor = Tensor(img.tobytes(), device='CPU')
@@ -73,19 +77,35 @@ def loader_process(q_in, q_out, X:Tensor, seed):
q_out.put(idx)
q_out.put(None)
def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None):
def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_first_batch=False):
from extra.datasets.imagenet import get_train_files, get_val_files
files = get_val_files() if val else get_train_files()
from extra.datasets.imagenet import get_imagenet_categories
cir = get_imagenet_categories()
BATCH_COUNT = min(32, len(files) // batch_size)
gen = shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
if pad_first_batch:
FIRST_BATCH_PAD = round_up(len(files), batch_size) - len(files)
else:
FIRST_BATCH_PAD = 0
file_count = FIRST_BATCH_PAD + len(files)
BATCH_COUNT = min(32, file_count // batch_size)
def _gen():
for _ in range(FIRST_BATCH_PAD): yield -1
yield from shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
gen = iter(_gen())
def enqueue_batch(num):
for idx in range(num*batch_size, (num+1)*batch_size):
fn = files[next(gen)]
q_in.put((idx, fn, val))
Y[idx] = cir[fn.split("/")[-2]]
fidx = next(gen)
if fidx != -1:
fn = files[fidx]
q_in.put((idx, fn, val))
Y[idx] = cir[fn.split("/")[-2]]
else:
# padding
q_in.put((idx, None, val))
Y[idx] = -1
shutdown = False
class Cookie:
@@ -126,8 +146,8 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None):
for bn in range(BATCH_COUNT): enqueue_batch(bn)
# NOTE: this is batch aligned, last ones are ignored
for _ in range(0, len(files)//batch_size): yield receive_batch()
# NOTE: this is batch aligned, last ones are ignored unless pad_first_batch is True
for _ in range(0, file_count//batch_size): yield receive_batch()
finally:
shutdown = True
# empty queues

View File

@@ -35,7 +35,7 @@ def eval_resnet():
# evaluation on the mlperf classes of the validation set from imagenet
from examples.mlperf.dataloader import batch_load_resnet
iterator = batch_load_resnet(getenv("BS", 128*6), val=getenv("VAL", 1), shuffle=False)
iterator = batch_load_resnet(getenv("BS", 128*6), val=getenv("VAL", 1), shuffle=False, pad_first_batch=True)
def data_get():
x,y,cookie = next(iterator)
return x.shard(GPUS, axis=0).realize(), y, cookie
@@ -51,9 +51,10 @@ def eval_resnet():
try: next_proc = data_get()
except StopIteration: next_proc = None
nd = time.perf_counter()
proc = proc[0].numpy() == proc[1] # this realizes the models and frees the cookies
y = np.array(proc[1])
proc = (proc[0].numpy() == y) & (y != -1) # this realizes the models and frees the cookies
n += proc.sum()
d += proc.size
d += (y != -1).sum()
et = time.perf_counter()
tlog(f"****** {n:5d}/{d:5d} {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
st = et