import os, random
from typing import List
import numpy as np
from PIL import Image
from tqdm import tqdm
import pickle
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Timing, Context
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count

class MyQueue:
  def __init__(self, multiple_readers=True, multiple_writers=True):
    self._reader, self._writer = connection.Pipe(duplex=False)
    self._rlock = Lock() if multiple_readers else None
    self._wlock = Lock() if multiple_writers else None
  def get(self):
    if self._rlock: self._rlock.acquire()
    ret = pickle.loads(self._reader.recv_bytes())
    if self._rlock: self._rlock.release()
    return ret
  def put(self, obj):
    if self._wlock: self._wlock.acquire()
    self._writer.send_bytes(pickle.dumps(obj))
    if self._wlock: self._wlock.release()

def shuffled_indices(n, seed=None):
  rng = random.Random(seed)
  indices = {}
  for i in range(n-1, -1, -1):
    j = rng.randint(0, i)
    if i not in indices: indices[i] = i
    if j not in indices: indices[j] = j
    indices[i], indices[j] = indices[j], indices[i]
    yield indices[i]
    del indices[i]
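
# shuffled_indices is a lazy Fisher-Yates shuffle: it yields every index of range(n)
# exactly once while only keeping the still-pending swaps in a dict, so memory stays
# small even for large n. A quick self-check, kept as a comment so importing this
# module stays side-effect free:
#   idxs = list(shuffled_indices(10, seed=42))
#   assert sorted(idxs) == list(range(10))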

def loader_process(q_in, q_out, X:Tensor, seed):
  # worker process: pulls (idx, filename, val) off q_in, decodes and preprocesses the
  # image, writes the pixels directly into the shared-memory tensor X, then reports idx on q_out
  import signal
  signal.signal(signal.SIGINT, lambda _, __: exit(0))

  from extra.datasets.imagenet import center_crop, preprocess_train

  with Context(DEBUG=0):
    while (_recv := q_in.get()) is not None:
      idx, fn, val = _recv
      img = Image.open(fn)
      img = img.convert('RGB') if img.mode != "RGB" else img

      if val:
        # eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
        # sudo apt-get install libjpeg-dev
        # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
        img = center_crop(img)
        img = np.array(img)
      else:
        # reseed rng for determinism
        if seed is not None:
          np.random.seed(seed * 2 ** 20 + idx)
          random.seed(seed * 2 ** 20 + idx)
        img = preprocess_train(img)

      # broken out
      #img_tensor = Tensor(img.tobytes(), device='CPU')
      #storage_tensor = X[idx].contiguous().realize().lazydata.realized
      #storage_tensor._copyin(img_tensor.numpy())

      # faster
      X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()

      # ideal
      #X[idx].assign(img.tobytes())   # NOTE: this is slow!

      q_out.put(idx)
    q_out.put(None)

def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None):
  from extra.datasets.imagenet import get_train_files, get_val_files
  files = get_val_files() if val else get_train_files()
  from extra.datasets.imagenet import get_imagenet_categories
  cir = get_imagenet_categories()

  BATCH_COUNT = min(32, len(files) // batch_size)

  gen = shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
  def enqueue_batch(num):
    for idx in range(num*batch_size, (num+1)*batch_size):
      fn = files[next(gen)]
      q_in.put((idx, fn, val))
      Y[idx] = cir[fn.split("/")[-2]]

  shutdown = False
  class Cookie:
    # handed out with each batch; once the consumer drops it, the freed slots are re-enqueued
    def __init__(self, num): self.num = num
    def __del__(self):
      if not shutdown:
        try: enqueue_batch(self.num)
        except StopIteration: pass

  gotten = [0]*BATCH_COUNT
  def receive_batch():
    # a batch is ready once all batch_size of its slots have been reported back by the workers
    while 1:
      num = q_out.get()//batch_size
      gotten[num] += 1
      if gotten[num] == batch_size: break
    gotten[num] = 0
    return X[num*batch_size:(num+1)*batch_size], Y[num*batch_size:(num+1)*batch_size], Cookie(num)

  #q_in, q_out = MyQueue(multiple_writers=False), MyQueue(multiple_readers=False)
  q_in, q_out = Queue(), Queue()

  sz = (batch_size*BATCH_COUNT, 224, 224, 3)
  if os.path.exists("/dev/shm/resnet_X"): os.unlink("/dev/shm/resnet_X")
  shm = shared_memory.SharedMemory(name="resnet_X", create=True, size=prod(sz))
  procs = []

  try:
    # disk:shm is slower
    #X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}")
    X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X")
    Y = [None] * (batch_size*BATCH_COUNT)

    for _ in range(cpu_count()):
      p = Process(target=loader_process, args=(q_in, q_out, X, seed))
      p.daemon = True
      p.start()
      procs.append(p)

    for bn in range(BATCH_COUNT): enqueue_batch(bn)

    # NOTE: this is batch aligned, last ones are ignored
    for _ in range(0, len(files)//batch_size): yield receive_batch()
  finally:
    shutdown = True
    # empty queues
    for _ in procs: q_in.put(None)
    q_in.close()
    for _ in procs:
      while q_out.get() is not None: pass
    q_out.close()
    # shutdown processes
    for p in procs: p.join()
    shm.close()
    shm.unlink()

def load_bert_file(fn:str) -> List[dict]:
  with open(fn, "rb") as f: data = pickle.load(f)
  return data

def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
  return {
    "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.float32),
    "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32),
    "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32),
  }
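
# Hedged sketch of consuming a batch from process_batch_bert, kept as a comment so
# importing this module stays side-effect free. It assumes each pickled sample holds
# numpy arrays shaped [1, seq_len], so concatenating along axis 0 yields [BS, seq_len]
# tensors; `some_fn`, `samples`, and `BS` below are hypothetical names, not part of this file:
#   samples = load_bert_file(some_fn)
#   batch = process_batch_bert(samples[:BS])
#   assert batch["input_ids"].shape == batch["input_mask"].shape
#   assert batch["input_ids"].shape[0] == BS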

# For train: stop when we run through all data
# For val: wrap around the val dataset and never stop
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 420
def batch_load_bert(BS:int, val=False):
  from extra.datasets.wikipedia import get_wiki_train_files, get_wiki_val_files
  files = get_wiki_val_files() if val else get_wiki_train_files()
  blob, end = [], False
  while files:  # as long as there is data, keep going
    while len(blob) < BS and not end:  # fill blob until there is enough for the next step
      blob.extend(load_bert_file(files.pop(0)))
      if not files:
        if val: files = get_wiki_val_files()  # wrap around the val set
        else: end = True  # end of train data - avoid pop on empty file list
    if len(blob) >= BS:  # the last train step may not have enough for a full batch
      yield process_batch_bert(blob[:BS])
      blob = blob[BS:]

if __name__ == "__main__":
  from extra.datasets.imagenet import get_train_files, get_val_files
  VAL = getenv("VAL", 1)
  files = get_val_files() if VAL else get_train_files()
  with tqdm(total=len(files)) as pbar:
    for x, y, c in batch_load_resnet(val=VAL):
      pbar.update(x.shape[0])
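
# Hedged usage sketch for batch_load_bert (as a comment; it needs the preprocessed
# wikipedia files from extra.datasets.wikipedia on disk, and BS=32 is just an example
# value). Per the comments above, the train loader stops after one pass while the val
# loader wraps around forever, so bound the val loop yourself:
#   for i, batch in zip(range(10), batch_load_bert(BS=32, val=True)):
#     print(i, batch["input_ids"].shape)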