mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
@@ -4,7 +4,7 @@ from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import pickle
|
||||
from tinygrad import dtypes, Tensor
|
||||
from tinygrad.helpers import getenv, prod, Timing
|
||||
from tinygrad.helpers import getenv, prod, Timing, Context
|
||||
from multiprocessing import Queue, Process, shared_memory, connection, Lock
|
||||
|
||||
class MyQueue:
|
||||
@@ -33,30 +33,31 @@ def shuffled_indices(n):
|
||||
del indices[i]
|
||||
|
||||
def loader_process(q_in, q_out, X:Tensor):
|
||||
while (_recv := q_in.get()) is not None:
|
||||
idx, fn = _recv
|
||||
img = Image.open(fn)
|
||||
img = img.convert('RGB') if img.mode != "RGB" else img
|
||||
with Context(DEBUG=0):
|
||||
while (_recv := q_in.get()) is not None:
|
||||
idx, fn = _recv
|
||||
img = Image.open(fn)
|
||||
img = img.convert('RGB') if img.mode != "RGB" else img
|
||||
|
||||
# eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
|
||||
# sudo apt-get install libjpeg-dev
|
||||
# CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
|
||||
rescale = min(img.size) / 256
|
||||
crop_left = (img.width - 224*rescale) / 2.0
|
||||
crop_top = (img.height - 224*rescale) / 2.0
|
||||
img = img.resize((224, 224), Image.BILINEAR, box=(crop_left, crop_top, crop_left+224*rescale, crop_top+224*rescale))
|
||||
# eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
|
||||
# sudo apt-get install libjpeg-dev
|
||||
# CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
|
||||
rescale = min(img.size) / 256
|
||||
crop_left = (img.width - 224*rescale) / 2.0
|
||||
crop_top = (img.height - 224*rescale) / 2.0
|
||||
img = img.resize((224, 224), Image.BILINEAR, box=(crop_left, crop_top, crop_left+224*rescale, crop_top+224*rescale))
|
||||
|
||||
# broken out
|
||||
#img_tensor = Tensor(img.tobytes(), device='CPU')
|
||||
#storage_tensor = X[idx].contiguous().realize().lazydata.realized
|
||||
#storage_tensor._copyin(img_tensor.numpy())
|
||||
# broken out
|
||||
#img_tensor = Tensor(img.tobytes(), device='CPU')
|
||||
#storage_tensor = X[idx].contiguous().realize().lazydata.realized
|
||||
#storage_tensor._copyin(img_tensor.numpy())
|
||||
|
||||
# faster
|
||||
X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
|
||||
# faster
|
||||
X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
|
||||
|
||||
# ideal
|
||||
#X[idx].assign(img.tobytes()) # NOTE: this is slow!
|
||||
q_out.put(idx)
|
||||
# ideal
|
||||
#X[idx].assign(img.tobytes()) # NOTE: this is slow!
|
||||
q_out.put(idx)
|
||||
|
||||
def batch_load_resnet(batch_size=64, val=False, shuffle=True):
|
||||
from extra.datasets.imagenet import get_train_files, get_val_files
|
||||
|
||||
@@ -53,18 +53,26 @@ class HIPAllocator(LRUAllocator):
|
||||
def _hostalloc(self, size:int): return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipHostMalloc(ctypes.byref(x), size, 0)))
|
||||
def copy_from_fd(self, dest, fd, offset, size):
|
||||
check(hip.hipSetDevice(self.device.device))
|
||||
if not hasattr(self, 'hb'): self.hb = [self._hostalloc(CHUNK_SIZE) for _ in range(2)]
|
||||
if not hasattr(self, 'hb'):
|
||||
self.hb = [self._hostalloc(CHUNK_SIZE) for _ in range(2)]
|
||||
self.hb_events = [None, None]
|
||||
self.hb_polarity = 0
|
||||
fo = io.FileIO(fd, "a+b", closefd=False)
|
||||
fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))
|
||||
copied_in = 0
|
||||
for local_offset in range(0, size+minor_offset, CHUNK_SIZE):
|
||||
local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE)
|
||||
fo.readinto(to_mv(self.hb[0], local_size))
|
||||
check(hip.hipDeviceSynchronize())
|
||||
check(hip.hipMemcpyAsync(ctypes.c_void_p(dest.value + copied_in), ctypes.c_void_p(self.hb[0].value + minor_offset),
|
||||
if self.hb_events[self.hb_polarity] is not None:
|
||||
check(hip.hipEventSynchronize(self.hb_events[self.hb_polarity]))
|
||||
check(hip.hipEventDestroy(self.hb_events[self.hb_polarity]))
|
||||
self.hb_events[self.hb_polarity] = None
|
||||
fo.readinto(to_mv(self.hb[self.hb_polarity], local_size))
|
||||
check(hip.hipMemcpyAsync(ctypes.c_void_p(dest.value + copied_in), ctypes.c_void_p(self.hb[self.hb_polarity].value + minor_offset),
|
||||
copy_size:=min(local_size-minor_offset, size-copied_in), hip.hipMemcpyHostToDevice, None))
|
||||
self.hb_events[self.hb_polarity] = init_c_var(hip.hipEvent_t(), lambda x: check(hip.hipEventCreate(ctypes.byref(x))))
|
||||
check(hip.hipEventRecord(self.hb_events[self.hb_polarity], None))
|
||||
copied_in += copy_size
|
||||
self.hb = self.hb[1:] + [self.hb[0]]
|
||||
self.hb_polarity = (self.hb_polarity+1) % len(self.hb)
|
||||
minor_offset = 0 # only on the first
|
||||
def copyin(self, dest:T, src: memoryview):
|
||||
check(hip.hipSetDevice(self.device.device))
|
||||
|
||||
Reference in New Issue
Block a user