diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index f4869de422..95f5e4beac 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -4,7 +4,7 @@ from PIL import Image from tqdm import tqdm import pickle from tinygrad import dtypes, Tensor -from tinygrad.helpers import getenv, prod, Timing +from tinygrad.helpers import getenv, prod, Timing, Context from multiprocessing import Queue, Process, shared_memory, connection, Lock class MyQueue: @@ -33,30 +33,31 @@ def shuffled_indices(n): del indices[i] def loader_process(q_in, q_out, X:Tensor): - while (_recv := q_in.get()) is not None: - idx, fn = _recv - img = Image.open(fn) - img = img.convert('RGB') if img.mode != "RGB" else img + with Context(DEBUG=0): + while (_recv := q_in.get()) is not None: + idx, fn = _recv + img = Image.open(fn) + img = img.convert('RGB') if img.mode != "RGB" else img - # eval: 76.08%, load in 0m7.366s (0m5.301s with simd) - # sudo apt-get install libjpeg-dev - # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd - rescale = min(img.size) / 256 - crop_left = (img.width - 224*rescale) / 2.0 - crop_top = (img.height - 224*rescale) / 2.0 - img = img.resize((224, 224), Image.BILINEAR, box=(crop_left, crop_top, crop_left+224*rescale, crop_top+224*rescale)) + # eval: 76.08%, load in 0m7.366s (0m5.301s with simd) + # sudo apt-get install libjpeg-dev + # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd + rescale = min(img.size) / 256 + crop_left = (img.width - 224*rescale) / 2.0 + crop_top = (img.height - 224*rescale) / 2.0 + img = img.resize((224, 224), Image.BILINEAR, box=(crop_left, crop_top, crop_left+224*rescale, crop_top+224*rescale)) - # broken out - #img_tensor = Tensor(img.tobytes(), device='CPU') - #storage_tensor = X[idx].contiguous().realize().lazydata.realized - #storage_tensor._copyin(img_tensor.numpy()) + # broken out + #img_tensor = Tensor(img.tobytes(), device='CPU') + #storage_tensor = X[idx].contiguous().realize().lazydata.realized + #storage_tensor._copyin(img_tensor.numpy()) - # faster - X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes() + # faster + X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes() - # ideal - #X[idx].assign(img.tobytes()) # NOTE: this is slow! - q_out.put(idx) + # ideal + #X[idx].assign(img.tobytes()) # NOTE: this is slow! + q_out.put(idx) def batch_load_resnet(batch_size=64, val=False, shuffle=True): from extra.datasets.imagenet import get_train_files, get_val_files diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index fcb62569c8..1f487342a9 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -53,18 +53,26 @@ class HIPAllocator(LRUAllocator): def _hostalloc(self, size:int): return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipHostMalloc(ctypes.byref(x), size, 0))) def copy_from_fd(self, dest, fd, offset, size): check(hip.hipSetDevice(self.device.device)) - if not hasattr(self, 'hb'): self.hb = [self._hostalloc(CHUNK_SIZE) for _ in range(2)] + if not hasattr(self, 'hb'): + self.hb = [self._hostalloc(CHUNK_SIZE) for _ in range(2)] + self.hb_events = [None, None] + self.hb_polarity = 0 fo = io.FileIO(fd, "a+b", closefd=False) fo.seek(offset - (minor_offset:=offset % PAGE_SIZE)) copied_in = 0 for local_offset in range(0, size+minor_offset, CHUNK_SIZE): local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE) - fo.readinto(to_mv(self.hb[0], local_size)) - check(hip.hipDeviceSynchronize()) - check(hip.hipMemcpyAsync(ctypes.c_void_p(dest.value + copied_in), ctypes.c_void_p(self.hb[0].value + minor_offset), + if self.hb_events[self.hb_polarity] is not None: + check(hip.hipEventSynchronize(self.hb_events[self.hb_polarity])) + check(hip.hipEventDestroy(self.hb_events[self.hb_polarity])) + self.hb_events[self.hb_polarity] = None + fo.readinto(to_mv(self.hb[self.hb_polarity], local_size)) + check(hip.hipMemcpyAsync(ctypes.c_void_p(dest.value + copied_in), ctypes.c_void_p(self.hb[self.hb_polarity].value + minor_offset), copy_size:=min(local_size-minor_offset, size-copied_in), hip.hipMemcpyHostToDevice, None)) + self.hb_events[self.hb_polarity] = init_c_var(hip.hipEvent_t(), lambda x: check(hip.hipEventCreate(ctypes.byref(x)))) + check(hip.hipEventRecord(self.hb_events[self.hb_polarity], None)) copied_in += copy_size - self.hb = self.hb[1:] + [self.hb[0]] + self.hb_polarity = (self.hb_polarity+1) % len(self.hb) minor_offset = 0 # only on the first def copyin(self, dest:T, src: memoryview): check(hip.hipSetDevice(self.device.device))