import os, random
from typing import List
import numpy as np
from PIL import Image
from tqdm import tqdm
import pickle
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Timing, Context
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count

class MyQueue:
  def __init__(self, multiple_readers=True, multiple_writers=True):
    self._reader, self._writer = connection.Pipe(duplex=False)
    self._rlock = Lock() if multiple_readers else None
    self._wlock = Lock() if multiple_writers else None
  def get(self):
    if self._rlock: self._rlock.acquire()
    ret = pickle.loads(self._reader.recv_bytes())
    if self._rlock: self._rlock.release()
    return ret
  def put(self, obj):
    if self._wlock: self._wlock.acquire()
    self._writer.send_bytes(pickle.dumps(obj))
    if self._wlock: self._wlock.release()
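
# Illustrative sketch, not part of the original loader: MyQueue sends pickled
# objects over a single pipe, taking a Lock only when that end is shared by
# multiple processes. A minimal round trip (assuming the default fork start
# method on Linux; the helper name below is hypothetical and never called):
def _myqueue_demo():
  q = MyQueue(multiple_readers=False, multiple_writers=False)  # one reader, one writer: no locks needed
  def worker(q): q.put({"hello": 42})
  p = Process(target=worker, args=(q,))
  p.start()
  print(q.get())  # -> {'hello': 42}
  p.join()
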
# lazily yields a random permutation of range(n) (a streaming Fisher-Yates shuffle),
# so the full index list for the dataset is never materialized
def shuffled_indices(n, seed=None):
  rng = random.Random(seed)
  indices = {}
  for i in range(n-1, -1, -1):
    j = rng.randint(0, i)
    if i not in indices: indices[i] = i
    if j not in indices: indices[j] = j
    indices[i], indices[j] = indices[j], indices[i]
    yield indices[i]
    del indices[i]
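
# Quick sanity sketch (illustrative helper, never called): consuming the
# generator above yields every index exactly once, deterministically per seed.
def _shuffled_indices_demo():
  out = list(shuffled_indices(10, seed=0))
  assert sorted(out) == list(range(10))             # a permutation of 0..9
  assert list(shuffled_indices(10, seed=0)) == out  # same seed -> same order
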
def loader_process(q_in, q_out, X:Tensor, seed):
  import signal
  signal.signal(signal.SIGINT, lambda _, __: exit(0))

  from extra.datasets.imagenet import center_crop, preprocess_train

  with Context(DEBUG=0):
    while (_recv := q_in.get()) is not None:
      idx, fn, val = _recv
      img = Image.open(fn)
      img = img.convert('RGB') if img.mode != "RGB" else img

      if val:
        # eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
        # sudo apt-get install libjpeg-dev
        # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
        img = center_crop(img)
        img = np.array(img)
      else:
        # reseed rng for determinism
        if seed is not None:
          np.random.seed(seed * 2 ** 20 + idx)
          random.seed(seed * 2 ** 20 + idx)
        img = preprocess_train(img)

      # broken out
      #img_tensor = Tensor(img.tobytes(), device='CPU')
      #storage_tensor = X[idx].contiguous().realize().lazydata.realized
      #storage_tensor._copyin(img_tensor.numpy())

      # faster
      X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()

      # ideal
      #X[idx].assign(img.tobytes())   # NOTE: this is slow!

      q_out.put(idx)
    q_out.put(None)
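
# Protocol sketch (illustrative, with a hypothetical image path, never called):
# the parent sends (idx, fn, val) tuples, a worker decodes fn into X[idx] and
# echoes idx back; a None from the parent shuts the worker down and is echoed
# back as None so the parent can drain the queue.
def _loader_protocol_demo(q_in, q_out, fn="/path/to/example.JPEG"):
  q_in.put((0, fn, True))      # ask a running worker to fill slot 0 with the eval-cropped image
  assert q_out.get() == 0      # the worker reports the slot index once the pixels are in X[0]
  q_in.put(None)               # shutdown sentinel...
  assert q_out.get() is None   # ...echoed by the worker on exit
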
def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None):
  from extra.datasets.imagenet import get_train_files, get_val_files
  files = get_val_files() if val else get_train_files()
  from extra.datasets.imagenet import get_imagenet_categories
  cir = get_imagenet_categories()
  BATCH_COUNT = min(32, len(files) // batch_size)

  gen = shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
  def enqueue_batch(num):
    for idx in range(num*batch_size, (num+1)*batch_size):
      fn = files[next(gen)]
      q_in.put((idx, fn, val))
      Y[idx] = cir[fn.split("/")[-2]]

  shutdown = False
  # a Cookie keeps a batch slot reserved while the consumer is using it; once the
  # consumer drops its reference, __del__ refills that slot with the next batch
  class Cookie:
    def __init__(self, num): self.num = num
    def __del__(self):
      if not shutdown:
        try: enqueue_batch(self.num)
        except StopIteration: pass

  gotten = [0]*BATCH_COUNT
  def receive_batch():
    # wait until all batch_size images of some slot have been reported done by the workers
    while 1:
      num = q_out.get()//batch_size
      gotten[num] += 1
      if gotten[num] == batch_size: break
    gotten[num] = 0
    return X[num*batch_size:(num+1)*batch_size], Y[num*batch_size:(num+1)*batch_size], Cookie(num)

  #q_in, q_out = MyQueue(multiple_writers=False), MyQueue(multiple_readers=False)
  q_in, q_out = Queue(), Queue()

  sz = (batch_size*BATCH_COUNT, 224, 224, 3)
  if os.path.exists("/dev/shm/resnet_X"): os.unlink("/dev/shm/resnet_X")
  shm = shared_memory.SharedMemory(name="resnet_X", create=True, size=prod(sz))
  procs = []

  try:
    # disk:shm is slower
    #X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}")
    X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X")
    Y = [None] * (batch_size*BATCH_COUNT)

    for _ in range(cpu_count()):
      p = Process(target=loader_process, args=(q_in, q_out, X, seed))
      p.daemon = True
      p.start()
      procs.append(p)

    for bn in range(BATCH_COUNT): enqueue_batch(bn)

    # NOTE: this is batch aligned, last ones are ignored
    for _ in range(0, len(files)//batch_size): yield receive_batch()
  finally:
    shutdown = True
    # empty queues
    for _ in procs: q_in.put(None)
    q_in.close()
    for _ in procs:
      while q_out.get() is not None: pass
    q_out.close()
    # shutdown processes
    for p in procs: p.join()
    shm.close()
    shm.unlink()
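
# Illustrative consumption sketch (never called; assumes the ImageNet files that
# get_val_files() points at are present). Each yielded tuple is (images, labels,
# cookie); the cookie must stay alive while the batch is in use, because dropping
# it is what triggers Cookie.__del__ and refills that shared-memory slot.
def _resnet_loader_demo():
  for x, y, cookie in batch_load_resnet(batch_size=64, val=True):
    print(x.shape, len(y))  # (64, 224, 224, 3) uint8 image batch and its 64 labels
    del cookie              # done with this batch; its slot may now be refilled
    break
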
def load_bert_file(fn:str) -> List[dict]:
  with open(fn, "rb") as f: data = pickle.load(f)
  return data

def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
  return {
    "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.float32),
    "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32),
    "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32),
  }
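
# Hedged sketch (never called): each pickled sample is assumed to be a dict of
# arrays with a leading batch axis, so concatenating BS of them along axis 0
# forms one batch. The shapes and dtypes below are illustrative assumptions,
# not values taken from the wikipedia preprocessing code.
def _process_batch_bert_demo():
  fake_sample = {k: np.zeros((1, 8), dtype=np.float32) for k in (
    "input_ids", "input_mask", "segment_ids", "masked_lm_positions",
    "masked_lm_ids", "masked_lm_weights", "next_sentence_labels")}
  batch = process_batch_bert([fake_sample, fake_sample])
  print(batch["input_ids"].shape)  # -> (2, 8)
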
# For train: stop when we run through all data
# For val: wrap around the val dataset and never stop
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 420
def batch_load_bert(BS:int, val=False):
  from extra.datasets.wikipedia import get_wiki_train_files, get_wiki_val_files
  files = get_wiki_val_files() if val else get_wiki_train_files()
  blob, end = [], False
  while files:  # as long as there is data, keep going
    while len(blob) < BS and not end:  # fill blob until there is enough for the next step
      blob.extend(load_bert_file(files.pop(0)))
      if not files:
        if val: files = get_wiki_val_files()  # wrap around the val set
        else: end = True  # end of train data - avoid pop on empty file list
    if len(blob) >= BS:  # the last train step may not have enough samples for a full batch
      yield process_batch_bert(blob[:BS])
      blob = blob[BS:]
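
# Illustrative usage sketch (never called; assumes the preprocessed wikipedia
# pickles exist). Train mode stops after one pass over the files; val mode wraps
# around forever, so the caller has to bound the loop itself.
def _bert_loader_demo(steps=10):
  for i, batch in enumerate(batch_load_bert(BS=32, val=True)):
    print(i, batch["input_ids"].shape)
    if i + 1 == steps: break
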
if __name__ == "__main__":
  from extra.datasets.imagenet import get_train_files, get_val_files
  VAL = getenv("VAL", 1)
  files = get_val_files() if VAL else get_train_files()
  with tqdm(total=len(files)) as pbar:
    for x,y,c in batch_load_resnet(val=VAL):
      pbar.update(x.shape[0])