tinygrad/examples/mlperf/dataloader.py
Francis Lata 707099487a Multiprocessing UNet3D dataloader (#4801)
* testing dataloader

* matching dataloader implementation for unet3d

* remove comments

* clean up dataloader

* add cookie and cleanup

* use shm_path when creating SharedMemory

* add support for testing resnet and unet3d dataloaders

* update dataset test to return preprocessed data directory in prep for dataloader testing

* pass preprocessed dataset directory properly

* update loader function for dataloader

* add shuffling on indices

* update shm name

* more cleanup for unet3d dataloader

* remove changes to tests

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-06-02 11:30:47 -04:00


import os, random, pickle, functools, itertools
from typing import List, Tuple
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Context, round_up
from collections import deque
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count, Pool
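# simple pipe-based queue built on pickle; an alternative to multiprocessing.Queue (see the commented-out line in batch_load_resnet)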
class MyQueue:
  def __init__(self, multiple_readers=True, multiple_writers=True):
    self._reader, self._writer = connection.Pipe(duplex=False)
    self._rlock = Lock() if multiple_readers else None
    self._wlock = Lock() if multiple_writers else None
  def get(self):
    if self._rlock: self._rlock.acquire()
    ret = pickle.loads(self._reader.recv_bytes())
    if self._rlock: self._rlock.release()
    return ret
  def put(self, obj):
    if self._wlock: self._wlock.acquire()
    self._writer.send_bytes(pickle.dumps(obj))
    if self._wlock: self._wlock.release()

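# lazy Fisher-Yates shuffle: yields a random permutation of range(n) while only tracking the indices touched so far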
def shuffled_indices(n, seed=None):
  rng = random.Random(seed)
  indices = {}
  for i in range(n-1, -1, -1):
    j = rng.randint(0, i)
    if i not in indices: indices[i] = i
    if j not in indices: indices[j] = j
    indices[i], indices[j] = indices[j], indices[i]
    yield indices[i]
    del indices[i]

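# worker process for batch_load_resnet: receives (idx, filename, val) from q_in, decodes and augments the image,
# writes it into slot idx of the shared-memory tensor X, then reports idx on q_out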
def loader_process(q_in, q_out, X:Tensor, seed):
  import signal
  signal.signal(signal.SIGINT, lambda _, __: exit(0))
  from extra.datasets.imagenet import center_crop, preprocess_train

  with Context(DEBUG=0):
    while (_recv := q_in.get()) is not None:
      idx, fn, val = _recv
      if fn is not None:
        img = Image.open(fn)
        img = img.convert('RGB') if img.mode != "RGB" else img

        if val:
          # eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
          # sudo apt-get install libjpeg-dev
          # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
          img = center_crop(img)
          img = np.array(img)
        else:
          # reseed rng for determinism
          if seed is not None:
            np.random.seed(seed * 2 ** 10 + idx)
            random.seed(seed * 2 ** 10 + idx)
          img = preprocess_train(img)
      else:
        # pad data with training mean
        img = np.tile(np.array([[[123.68, 116.78, 103.94]]], dtype=np.uint8), (224, 224, 1))

      # broken out
      #img_tensor = Tensor(img.tobytes(), device='CPU')
      #storage_tensor = X[idx].contiguous().realize().lazydata.realized
      #storage_tensor._copyin(img_tensor.numpy())

      # faster
      X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()

      # ideal
      #X[idx].assign(img.tobytes())   # NOTE: this is slow!

      q_out.put(idx)
  q_out.put(None)

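# yields (X, Y, Cookie) ImageNet batches: X is a uint8 tensor backed by /dev/shm, Y a list of labels (-1 for padding);
# dropping the Cookie re-enqueues that batch slot so it gets refilled by the workers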
def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_first_batch=False):
  from extra.datasets.imagenet import get_train_files, get_val_files
  files = get_val_files() if val else get_train_files()
  from extra.datasets.imagenet import get_imagenet_categories
  cir = get_imagenet_categories()

  if pad_first_batch:
    FIRST_BATCH_PAD = round_up(len(files), batch_size) - len(files)
  else:
    FIRST_BATCH_PAD = 0
  file_count = FIRST_BATCH_PAD + len(files)
  BATCH_COUNT = min(32, file_count // batch_size)

  def _gen():
    for _ in range(FIRST_BATCH_PAD): yield -1
    yield from shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
  gen = iter(_gen())

  def enqueue_batch(num):
    for idx in range(num*batch_size, (num+1)*batch_size):
      fidx = next(gen)
      if fidx != -1:
        fn = files[fidx]
        q_in.put((idx, fn, val))
        Y[idx] = cir[fn.split("/")[-2]]
      else:
        # padding
        q_in.put((idx, None, val))
        Y[idx] = -1

  shutdown = False
  class Cookie:
    def __init__(self, num): self.num = num
    def __del__(self):
      if not shutdown:
        try: enqueue_batch(self.num)
        except StopIteration: pass

  gotten = [0]*BATCH_COUNT
  def receive_batch():
    while 1:
      num = q_out.get()//batch_size
      gotten[num] += 1
      if gotten[num] == batch_size: break
    gotten[num] = 0
    return X[num*batch_size:(num+1)*batch_size], Y[num*batch_size:(num+1)*batch_size], Cookie(num)

  #q_in, q_out = MyQueue(multiple_writers=False), MyQueue(multiple_readers=False)
  q_in, q_out = Queue(), Queue()

  sz = (batch_size*BATCH_COUNT, 224, 224, 3)
  if os.path.exists("/dev/shm/resnet_X"): os.unlink("/dev/shm/resnet_X")
  shm = shared_memory.SharedMemory(name="resnet_X", create=True, size=prod(sz))
  procs = []

  try:
    # disk:shm is slower
    #X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}")
    X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X")
    Y = [None] * (batch_size*BATCH_COUNT)

    for _ in range(cpu_count()):
      p = Process(target=loader_process, args=(q_in, q_out, X, seed))
      p.daemon = True
      p.start()
      procs.append(p)

    for bn in range(BATCH_COUNT): enqueue_batch(bn)

    # NOTE: this is batch aligned, last ones are ignored unless pad_first_batch is True
    for _ in range(0, file_count//batch_size): yield receive_batch()
  finally:
    shutdown = True
    # empty queues
    for _ in procs: q_in.put(None)
    q_in.close()
    for _ in procs:
      while q_out.get() is not None: pass
    q_out.close()
    # shutdown processes
    for p in procs: p.join()
    shm.close()
    try:
      shm.unlink()
    except FileNotFoundError:
      # happens with BENCHMARK set
      pass

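# BERT: each pickle file holds a list of pre-tokenized samples; lru_cache avoids re-reading files that are still hot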
@functools.lru_cache(maxsize=128)
def load_bert_file(fn:str) -> List[dict]:
  with open(fn, "rb") as f: data = pickle.load(f)
  return data

def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
  return {
    "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.default_float),
    "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32),
    "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32),
  }

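# shuffles the order of the dataset parts while keeping the files within each part in their numbered order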
def shuffle_parts(file_paths: List[str]) -> List[str]:
  parts = {}
  for f in file_paths:
    part = Path(f).stem.split('_')[0]
    if part not in parts: parts[part] = []
    parts[part].append(f)

  part_ids = list(parts.keys())
  random.shuffle(part_ids)

  shuffled_files = []
  for p in part_ids:
    parts[p].sort(key=lambda x: int(Path(x).stem.split('_')[1]))
    shuffled_files.extend(parts[p])
  return shuffled_files

def random_sample(data: List[str]):
  index = random.randint(0, len(data) - 1)
  selected_sample = data[index]
  return selected_sample, index

def load_datasample(file_and_offset:Tuple[str, int]) -> List[dict]:
  data = load_bert_file(file_and_offset[0])
  return data[file_and_offset[1]]

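# training batches are drawn at random from a 1000-sample "active set" that is refilled from the remaining samples,
# approximating the shuffle buffer used by the MLPerf reference implementation linked below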
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
def batch_load_train_bert(BS:int, start_step:int = 0):
  from extra.datasets.wikipedia import get_wiki_train_files
  files = shuffle_parts(get_wiki_train_files())
  dataset = []
  for f in tqdm(files, desc="Building dataset"):
    lists = [(f, o) for o in range(int(Path(f).stem.split("_")[3].split(".")[0]))]
    dataset.extend(lists)

  dataset = dataset[start_step:]

  active_set = deque(dataset[:1000])
  remaining_set = deque(dataset[1000:])

  while dataset:
    blob = []
    for _ in range(BS):
      if active_set:
        index = random.randint(0, len(active_set) - 1)
        sample = active_set[index]
        active_set.remove(sample)
        blob.append(sample)
        if remaining_set:
          active_set.append(remaining_set.popleft())
    yield process_batch_bert([load_datasample(sample) for sample in blob])

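# validation batches walk the concatenated eval set in order, wrapping around at the end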
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 416
def batch_load_val_bert(BS:int):
  from extra.datasets.wikipedia import get_wiki_val_files
  files = get_wiki_val_files()
  dataset = list(itertools.chain.from_iterable([load_bert_file(f) for f in files]))
  idx = 0
  while True:
    start_idx = (idx * BS) % len(dataset)
    end_idx = ((idx + 1) * BS) % len(dataset)
    if start_idx < end_idx:
      yield process_batch_bert(dataset[start_idx:end_idx])
    else:  # wrap around the end to the beginning of the dataset
      yield process_batch_bert(dataset[start_idx:] + dataset[:end_idx])
    idx += 1

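# worker process for batch_load_unet3d: loads a preprocessed KiTS19 case, applies the training augmentations,
# and writes the volume and its segmentation mask into the shared-memory tensors X and Y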
def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tensor, Y:Tensor):
  from extra.datasets.kits19 import rand_balanced_crop, rand_flip, random_brightness_augmentation, gaussian_noise

  while (data := queue_in.get()) is not None:
    idx, fn, val = data
    case_name = os.path.basename(fn).split("_x.npy")[0]
    x, y = np.load(preprocessed_dataset_dir / f"{case_name}_x.npy"), np.load(preprocessed_dataset_dir / f"{case_name}_y.npy")

    if not val:
      if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
      x, y = rand_balanced_crop(x, y)
      x, y = rand_flip(x, y)
      x, y = x.astype(np.float32), y.astype(np.uint8)
      x = random_brightness_augmentation(x)
      x = gaussian_noise(x)

    X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
    Y[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()

    queue_out.put(idx)
  queue_out.put(None)

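# yields (X, Y, Cookie) batches of preprocessed KiTS19 volumes and segmentation masks from shared memory,
# using the same Cookie-based batch-slot refill scheme as batch_load_resnet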
def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=False, shuffle:bool=True, seed=None):
  assert preprocessed_dataset_dir is not None, "run preprocess_data on kits19"

  files = sorted(list(preprocessed_dataset_dir.glob("*_x.npy")))
  file_indices = list(range(len(files)))
  batch_count = min(32, len(files) // batch_size)

  queue_in, queue_out = Queue(), Queue()
  procs, data_out_count = [], [0] * batch_count
  shm_name = "unet3d"
  sz, shm_path = (batch_size * batch_count, 1, 128, 128, 128), f"/dev/shm/{shm_name}"
  if os.path.exists(shm_path): os.unlink(shm_path)
  shm = shared_memory.SharedMemory(name=shm_name, create=True, size=prod(sz))

  shutdown = False
  class Cookie:
    def __init__(self, bc):
      self.bc = bc
    def __del__(self):
      if not shutdown:
        try: enqueue_batch(self.bc)
        except StopIteration: pass

  def enqueue_batch(bc):
    for idx in range(bc * batch_size, (bc+1) * batch_size):
      fn = files[next(ds_iter)]
      queue_in.put((idx, fn, val))

  def shuffle_indices(file_indices, seed=None):
    rng = random.Random(seed)
    rng.shuffle(file_indices)

  if shuffle: shuffle_indices(file_indices, seed=seed)
  ds_iter = iter(file_indices)

  try:
    X = Tensor.empty(*sz, dtype=dtypes.float32, device=f"disk:{shm_path}")
    Y = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:{shm_path}")

    for _ in range(cpu_count()):
      proc = Process(target=load_unet3d_data, args=(preprocessed_dataset_dir, seed, queue_in, queue_out, X, Y))
      proc.daemon = True
      proc.start()
      procs.append(proc)

    for bc in range(batch_count):
      enqueue_batch(bc)

    for _ in range(len(files) // batch_size):
      while True:
        bc = queue_out.get() // batch_size
        data_out_count[bc] += 1
        if data_out_count[bc] == batch_size: break
      data_out_count[bc] = 0
      yield X[bc * batch_size:(bc + 1) * batch_size], Y[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
  finally:
    shutdown = True

    for _ in procs: queue_in.put(None)
    queue_in.close()
    for _ in procs:
      while queue_out.get() is not None: pass
    queue_out.close()

    # shutdown processes
    for proc in procs: proc.join()

    shm.close()
    try:
      shm.unlink()
    except FileNotFoundError:
      # happens with BENCHMARK set
      pass

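# manual test entry point: MODEL=resnet (default) or MODEL=unet3d picks the loader, VAL picks the split;
# iterates over the selected dataset once with a progress bar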
if __name__ == "__main__":
  def load_unet3d(val):
    assert not val, "validation set is not supported due to different sizes on inputs"
    from extra.datasets.kits19 import get_train_files, get_val_files, preprocess_dataset, BASEDIR
    preprocessed_dataset_dir = (BASEDIR / ".." / "preprocessed" / ("val" if val else "train"))
    files = get_val_files() if val else get_train_files()
    if not preprocessed_dataset_dir.exists(): preprocess_dataset(files, preprocessed_dataset_dir, val)
    with tqdm(total=len(files)) as pbar:
      for x, _, _ in batch_load_unet3d(preprocessed_dataset_dir, val=val):
        pbar.update(x.shape[0])

  def load_resnet(val):
    from extra.datasets.imagenet import get_train_files, get_val_files
    files = get_val_files() if val else get_train_files()
    with tqdm(total=len(files)) as pbar:
      for x,y,c in batch_load_resnet(val=val):
        pbar.update(x.shape[0])

  load_fn_name = f"load_{getenv('MODEL', 'resnet')}"
  if load_fn_name in globals():
    globals()[load_fn_name](getenv("VAL", 1))