mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
pad first batch of imagenet dataloader and update eval (#4368)
* pad first batch of imagenet dataloader and update eval
* pad zero instead of empty for training
This commit is contained in:
@@ -5,7 +5,7 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from tinygrad import dtypes, Tensor
|
||||
from tinygrad.helpers import getenv, prod, Timing, Context
|
||||
from tinygrad.helpers import getenv, prod, Context, round_up
|
||||
from collections import deque
|
||||
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count, Pool
|
||||
|
||||
@@ -44,21 +44,25 @@ def loader_process(q_in, q_out, X:Tensor, seed):
|
||||
with Context(DEBUG=0):
|
||||
while (_recv := q_in.get()) is not None:
|
||||
idx, fn, val = _recv
|
||||
img = Image.open(fn)
|
||||
img = img.convert('RGB') if img.mode != "RGB" else img
|
||||
if fn is not None:
|
||||
img = Image.open(fn)
|
||||
img = img.convert('RGB') if img.mode != "RGB" else img
|
||||
|
||||
if val:
|
||||
# eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
|
||||
# sudo apt-get install libjpeg-dev
|
||||
# CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
|
||||
img = center_crop(img)
|
||||
img = np.array(img)
|
||||
if val:
|
||||
# eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
|
||||
# sudo apt-get install libjpeg-dev
|
||||
# CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
|
||||
img = center_crop(img)
|
||||
img = np.array(img)
|
||||
else:
|
||||
# reseed rng for determinism
|
||||
if seed is not None:
|
||||
np.random.seed(seed * 2 ** 20 + idx)
|
||||
random.seed(seed * 2 ** 20 + idx)
|
||||
img = preprocess_train(img)
|
||||
else:
|
||||
# reseed rng for determinism
|
||||
if seed is not None:
|
||||
np.random.seed(seed * 2 ** 20 + idx)
|
||||
random.seed(seed * 2 ** 20 + idx)
|
||||
img = preprocess_train(img)
|
||||
# pad zeros
|
||||
img = np.zeros((224, 224, 3), dtype=np.uint8)
|
||||
|
||||
# broken out
|
||||
#img_tensor = Tensor(img.tobytes(), device='CPU')
|
||||
@@ -73,19 +77,35 @@ def loader_process(q_in, q_out, X:Tensor, seed):
|
||||
q_out.put(idx)
|
||||
q_out.put(None)
|
||||
|
||||
def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None):
|
||||
def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_first_batch=False):
|
||||
from extra.datasets.imagenet import get_train_files, get_val_files
|
||||
files = get_val_files() if val else get_train_files()
|
||||
from extra.datasets.imagenet import get_imagenet_categories
|
||||
cir = get_imagenet_categories()
|
||||
BATCH_COUNT = min(32, len(files) // batch_size)
|
||||
|
||||
gen = shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
|
||||
if pad_first_batch:
|
||||
FIRST_BATCH_PAD = round_up(len(files), batch_size) - len(files)
|
||||
else:
|
||||
FIRST_BATCH_PAD = 0
|
||||
file_count = FIRST_BATCH_PAD + len(files)
|
||||
BATCH_COUNT = min(32, file_count // batch_size)
|
||||
|
||||
def _gen():
|
||||
for _ in range(FIRST_BATCH_PAD): yield -1
|
||||
yield from shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
|
||||
gen = iter(_gen())
|
||||
|
||||
def enqueue_batch(num):
|
||||
for idx in range(num*batch_size, (num+1)*batch_size):
|
||||
fn = files[next(gen)]
|
||||
q_in.put((idx, fn, val))
|
||||
Y[idx] = cir[fn.split("/")[-2]]
|
||||
fidx = next(gen)
|
||||
if fidx != -1:
|
||||
fn = files[fidx]
|
||||
q_in.put((idx, fn, val))
|
||||
Y[idx] = cir[fn.split("/")[-2]]
|
||||
else:
|
||||
# padding
|
||||
q_in.put((idx, None, val))
|
||||
Y[idx] = -1
|
||||
|
||||
shutdown = False
|
||||
class Cookie:
|
||||
@@ -126,8 +146,8 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None):
|
||||
|
||||
for bn in range(BATCH_COUNT): enqueue_batch(bn)
|
||||
|
||||
# NOTE: this is batch aligned, last ones are ignored
|
||||
for _ in range(0, len(files)//batch_size): yield receive_batch()
|
||||
# NOTE: this is batch aligned, last ones are ignored unless pad_first_batch is True
|
||||
for _ in range(0, file_count//batch_size): yield receive_batch()
|
||||
finally:
|
||||
shutdown = True
|
||||
# empty queues
|
||||
|
||||
@@ -35,7 +35,7 @@ def eval_resnet():
|
||||
|
||||
# evaluation on the mlperf classes of the validation set from imagenet
|
||||
from examples.mlperf.dataloader import batch_load_resnet
|
||||
iterator = batch_load_resnet(getenv("BS", 128*6), val=getenv("VAL", 1), shuffle=False)
|
||||
iterator = batch_load_resnet(getenv("BS", 128*6), val=getenv("VAL", 1), shuffle=False, pad_first_batch=True)
|
||||
def data_get():
|
||||
x,y,cookie = next(iterator)
|
||||
return x.shard(GPUS, axis=0).realize(), y, cookie
|
||||
@@ -51,9 +51,10 @@ def eval_resnet():
|
||||
try: next_proc = data_get()
|
||||
except StopIteration: next_proc = None
|
||||
nd = time.perf_counter()
|
||||
proc = proc[0].numpy() == proc[1] # this realizes the models and frees the cookies
|
||||
y = np.array(proc[1])
|
||||
proc = (proc[0].numpy() == y) & (y != -1) # this realizes the models and frees the cookies
|
||||
n += proc.sum()
|
||||
d += proc.size
|
||||
d += (y != -1).sum()
|
||||
et = time.perf_counter()
|
||||
tlog(f"****** {n:5d}/{d:5d} {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
|
||||
st = et
|
||||
|
||||
Reference in New Issue
Block a user