tinygrad/examples/mlperf/dataloader.py
Elias Wahl, 3a48773f1a: BERT dataloader (#4252), 2024-04-23


import os, random
from typing import List
import numpy as np
from PIL import Image
from tqdm import tqdm
import pickle
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Timing, Context
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count

class MyQueue:
  def __init__(self, multiple_readers=True, multiple_writers=True):
    self._reader, self._writer = connection.Pipe(duplex=False)
    self._rlock = Lock() if multiple_readers else None
    self._wlock = Lock() if multiple_writers else None
  def get(self):
    if self._rlock: self._rlock.acquire()
    ret = pickle.loads(self._reader.recv_bytes())
    if self._rlock: self._rlock.release()
    return ret
  def put(self, obj):
    if self._wlock: self._wlock.acquire()
    self._writer.send_bytes(pickle.dumps(obj))
    if self._wlock: self._wlock.release()
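
# Illustrative sketch, not part of the original file: MyQueue is a minimal pickle-over-pipe
# queue (an alternative to multiprocessing.Queue; a commented-out use of it appears below in
# batch_load_resnet). The helper names here are hypothetical, and this assumes the default
# "fork" start method on Linux so the pipe and locks are inherited by the child process.
def _myqueue_worker(q):
  q.put({"hello": 42})  # any picklable object works
def _myqueue_example():
  q = MyQueue(multiple_readers=False)  # a single reader (the parent), so no read lock needed
  p = Process(target=_myqueue_worker, args=(q,))
  p.start()
  print(q.get())  # -> {'hello': 42}
  p.join()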

def shuffled_indices(n, seed=None):
  rng = random.Random(seed)
  indices = {}
  for i in range(n-1, -1, -1):
    j = rng.randint(0, i)
    if i not in indices: indices[i] = i
    if j not in indices: indices[j] = j
    indices[i], indices[j] = indices[j], indices[i]
    yield indices[i]
    del indices[i]
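
# Illustrative sketch, not part of the original file: shuffled_indices is a lazy Fisher-Yates
# shuffle that yields a random permutation of range(n) while only keeping the still-pending
# swaps in the dict. The helper name below is hypothetical.
def _shuffled_indices_example():
  out = list(shuffled_indices(10, seed=42))
  assert sorted(out) == list(range(10))  # it is a permutation of 0..9
  print(out)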

def loader_process(q_in, q_out, X:Tensor, seed):
  import signal
  signal.signal(signal.SIGINT, lambda _, __: exit(0))
  from extra.datasets.imagenet import center_crop, preprocess_train

  with Context(DEBUG=0):
    while (_recv := q_in.get()) is not None:
      idx, fn, val = _recv
      img = Image.open(fn)
      img = img.convert('RGB') if img.mode != "RGB" else img

      if val:
        # eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
        # sudo apt-get install libjpeg-dev
        # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
        img = center_crop(img)
        img = np.array(img)
      else:
        # reseed rng for determinism
        if seed is not None:
          np.random.seed(seed * 2 ** 20 + idx)
          random.seed(seed * 2 ** 20 + idx)
        img = preprocess_train(img)

      # broken out
      #img_tensor = Tensor(img.tobytes(), device='CPU')
      #storage_tensor = X[idx].contiguous().realize().lazydata.realized
      #storage_tensor._copyin(img_tensor.numpy())

      # faster
      X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()

      # ideal
      #X[idx].assign(img.tobytes())  # NOTE: this is slow!

      q_out.put(idx)
  q_out.put(None)
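
# Illustrative sketch, not part of the original file: the as_buffer() write above fills the
# /dev/shm backing of X in place, so decoded pixels never travel through a queue back to the
# parent. A hypothetical stand-alone version of that idea with plain numpy + shared_memory:
def _shm_write_example():
  shm = shared_memory.SharedMemory(create=True, size=224*224*3)
  view = np.ndarray((224, 224, 3), dtype=np.uint8, buffer=shm.buf)
  view[:] = 0  # a worker would write decoded image bytes here, with no pickling copy
  shm.close()
  shm.unlink()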

def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None):
  from extra.datasets.imagenet import get_train_files, get_val_files
  files = get_val_files() if val else get_train_files()
  from extra.datasets.imagenet import get_imagenet_categories
  cir = get_imagenet_categories()
  BATCH_COUNT = min(32, len(files) // batch_size)

  gen = shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
  def enqueue_batch(num):
    for idx in range(num*batch_size, (num+1)*batch_size):
      fn = files[next(gen)]
      q_in.put((idx, fn, val))
      Y[idx] = cir[fn.split("/")[-2]]

  shutdown = False
  class Cookie:
    def __init__(self, num): self.num = num
    def __del__(self):
      if not shutdown:
        try: enqueue_batch(self.num)
        except StopIteration: pass

  gotten = [0]*BATCH_COUNT
  def receive_batch():
    while 1:
      num = q_out.get()//batch_size
      gotten[num] += 1
      if gotten[num] == batch_size: break
    gotten[num] = 0
    return X[num*batch_size:(num+1)*batch_size], Y[num*batch_size:(num+1)*batch_size], Cookie(num)

  #q_in, q_out = MyQueue(multiple_writers=False), MyQueue(multiple_readers=False)
  q_in, q_out = Queue(), Queue()

  sz = (batch_size*BATCH_COUNT, 224, 224, 3)
  if os.path.exists("/dev/shm/resnet_X"): os.unlink("/dev/shm/resnet_X")
  shm = shared_memory.SharedMemory(name="resnet_X", create=True, size=prod(sz))
  procs = []

  try:
    # disk:shm is slower
    #X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}")
    X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X")
    Y = [None] * (batch_size*BATCH_COUNT)

    for _ in range(cpu_count()):
      p = Process(target=loader_process, args=(q_in, q_out, X, seed))
      p.daemon = True
      p.start()
      procs.append(p)

    for bn in range(BATCH_COUNT): enqueue_batch(bn)

    # NOTE: this is batch aligned, last ones are ignored
    for _ in range(0, len(files)//batch_size): yield receive_batch()
  finally:
    shutdown = True
    # empty queues
    for _ in procs: q_in.put(None)
    q_in.close()
    for _ in procs:
      while q_out.get() is not None: pass
    q_out.close()
    # shutdown processes
    for p in procs: p.join()
    shm.close()
    shm.unlink()
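
# Illustrative sketch, not part of the original file: a hypothetical consumer of
# batch_load_resnet. The returned cookie must stay referenced while x/y are in use;
# once it is garbage collected, Cookie.__del__ re-enqueues that slot, so the loader
# processes may overwrite the shared-memory region the batch points into.
def _batch_load_resnet_example(steps=4):
  for i, (x, y, cookie) in enumerate(batch_load_resnet(batch_size=64, val=True)):
    if i >= steps: break
    print(x.shape, y[:4])  # x is a uint8 NHWC view into the /dev/shm buffer
    del cookie  # this slot may now be refilled by the workers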

def load_bert_file(fn:str) -> List[dict]:
  with open(fn, "rb") as f: data = pickle.load(f)
  return data

def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
  return {
    "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.float32),
    "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32),
    "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32),
  }
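
# Illustrative sketch, not part of the original file: each per-sample dict is expected to hold
# (1, seq_len)-shaped arrays, so concatenating on axis 0 stacks samples into a batch. The data
# below is fake; the 512/76 sizes are assumed from the MLPerf BERT pretraining config.
def _process_batch_bert_example():
  fake = {k: np.zeros((1, n)) for k, n in [
    ("input_ids", 512), ("input_mask", 512), ("segment_ids", 512),
    ("masked_lm_positions", 76), ("masked_lm_ids", 76), ("masked_lm_weights", 76),
    ("next_sentence_labels", 1)]}
  batch = process_batch_bert([fake, fake])
  print(batch["input_ids"].shape)  # -> (2, 512)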

# For train: Stop when we run through all data
# For val: Wrap around val dataset and never stop
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 420
def batch_load_bert(BS:int, val=False):
  from extra.datasets.wikipedia import get_wiki_train_files, get_wiki_val_files
  files = get_wiki_val_files() if val else get_wiki_train_files()
  blob, end = [], False
  while files: # As long as there is data, keep going
    while len(blob) < BS and not end: # Fill blob until there is enough for next step
      blob.extend(load_bert_file(files.pop(0)))
      if not files:
        if val: files = get_wiki_val_files() # wrap around the val dataset
        else: end = True # End of train data - avoid pop on empty file list
    if len(blob) >= BS: # if last train step does not have enough for a full batch
      yield process_batch_bert(blob[:BS])
      blob = blob[BS:]
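
# Illustrative sketch, not part of the original file: a hypothetical consumer of the BERT
# loader. With val=True the generator wraps around and never stops, so the caller bounds it.
def _batch_load_bert_example(steps=10):
  for i, batch in enumerate(batch_load_bert(BS=32, val=True)):
    if i >= steps: break
    print(batch["input_ids"].shape)  # e.g. (32, 512) with the MLPerf wikipedia data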

if __name__ == "__main__":
  from extra.datasets.imagenet import get_train_files, get_val_files
  VAL = getenv("VAL", 1)
  files = get_val_files() if VAL else get_train_files()
  with tqdm(total=len(files)) as pbar:
    for x,y,c in batch_load_resnet(val=VAL):
      pbar.update(x.shape[0])