Files
tinygrad/examples/mlperf/dataloader.py
David Hou 199f7c4342 MLPerf Resnet (cleaned up) (#3573)
* this is a lot of stuff

TEST_TRAIN env for less data

don't diskcache get_train_files

debug message

no lr_scaler for fp32

comment, typo

type stuff

don't destructure proc

make batchnorm parameters float

make batchnorm parameters float

resnet18, checkpointing

hack up checkpointing to keep the names in there

oops

wandb_resume

lower lr

eval/ckpt use e+1

lars

report top_1_acc

some wandb stuff

split fw and bw steps to save memory

oops

save model when reach target

formatting

make sgd hparams consistent

just always write the cats tag...

pass X and Y into backward_step to trigger input replace

shuffle eval set to fix batchnorm eval

dataset is sorted by class, so the means and variances are all wrong

small cleanup

hack restore only one copy of each tensor

do bufs from lin after cache check (lru should handle it fine)

record epoch in wandb

more digits for topk in eval

more env vars

small cleanup

cleanup hack tricks

cleanup hack tricks

don't save ckpt for testeval

cleanup

diskcache train file glob

clean up a little

device_str

SCE into tensor

small

small

log_softmax out of resnet.py

oops

hack :(

comments

HeNormal, track gradient norm

oops

log SYNCBN to wandb

real truncnorm

fewer samples for truncated normal

custom init for Linear

log layer stats

small

Revert "small"

This reverts commit 988f4c1cf3.

Revert "log layer stats"

This reverts commit 9d98224585.

rename BNSYNC to SYNCBN to be consistent with cifar

optional TRACK_NORMS

fix label smoothing :/

lars skip list

only weight decay if not in skip list

comment

default 0 TRACK_NORMS

don't allocate beam scratch buffers if in cache

clean up data pipeline, unsplit train/test, put back a hack

remove print

run test_indexing on remu (#3404)

* emulated ops_hip infra

* add int4

* include test_indexing in remu

* Revert "Merge branch 'remu-dev-mac'"

This reverts commit 6870457e57, reversing
changes made to 3c4c8c9e16.

fix bad seeding

UnsyncBatchNorm2d but with synced trainable weights

label downsample batchnorm in Bottleneck

:/

:/

i mean... it runs... it hits the acc... it's fast...

new unsyncbatchnorm for resnet

small fix

don't do assign buffer reuse for axis change

* remove changes

* remove changes

* move LARS out of tinygrad/

* rand_truncn rename

* whitespace

* stray whitespace

* no more gnorms

* delete some dataloading stuff

* remove comment

* clean up train script

* small comments

* move checkpointing stuff to mlperf helpers

* if WANDB

* small comments

* remove whitespace change

* new unsynced bn

* clean up prints / loop vars

* whitespace

* undo nn changes

* clean up loops

* rearrange getenvs

* cpu_count()

* PolynomialLR whitespace

* move he_normal out

* cap warmup in polylr

* rearrange wandb log

* realize both x and y in data_get

* use double quotes

* combine prints in ckpts resume

* take UBN from cifar

* running_var

* whitespace

* whitespace

* typo

* if instead of ternary for resnet downsample

* clean up dataloader cleanup a little?

* separate rng for shuffle

* clean up imports in model_train

* clean up imports

* don't realize copyin in data_get

* remove TESTEVAL (train dataloader didn't get freed every loop)

* adjust wandb_config entries a little

* clean up wandb config dict

* reduce lines

* whitespace

* shorter lines

* put shm unlink back, but it doesn't seem to do anything

* don't pass seed per task

* monkeypatch batchnorm

* the reseed was wrong

* add epoch number to desc

* don't use unsyncedbatchnorm if syncbn=1

* put back downsample name

* eval every epoch

* Revert "the reseed was wrong"

This reverts commit 3440a07dff3f40e8a8d156ca3f1938558a59249f.

* cast lr in onecycle

* support fp16

* cut off kernel if expand after reduce

* test polynomial lr

* move polynomiallr to examples/mlperf

* working PolynomialDecayWithWarmup + tests.......

add lars_util.py, oops

* keep lars_util.py as intact as possible, simplify our interface

* no more half

* polylr and lars were merged

* undo search change

* override Linear init

* remove half stuff from model_train

* update scheduler init with new args

* don't divide by input mean

* mistake in resnet.py

* restore whitespace in resnet.py

* add test_data_parallel_resnet_train_step

* move initializers out of resnet.py

* unused imports

* log_softmax to model output in test to fix precision flakiness

* log_softmax to model output in test to fix precision flakiness

* oops, don't realize here

* is None

* realize initializations in order for determinism

* BENCHMARK flag for number of steps

* add resnet to benchmark.yml

* return instead of break

* missing return

* cpu_count, rearrange benchmark.yml

* unused variable

* disable tqdm if BENCHMARK

* getenv WARMUP_EPOCHS

* unlink disktensor shm file if exists

* terminate instead of join

* properly shut down queues

* use hip in benchmark for now

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2024-03-14 00:53:41 -04:00

149 lines
4.9 KiB
Python

import os, random
import numpy as np
from PIL import Image
from tqdm import tqdm
import pickle
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Timing, Context
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count

class MyQueue:
  def __init__(self, multiple_readers=True, multiple_writers=True):
    self._reader, self._writer = connection.Pipe(duplex=False)
    self._rlock = Lock() if multiple_readers else None
    self._wlock = Lock() if multiple_writers else None
  def get(self):
    if self._rlock: self._rlock.acquire()
    ret = pickle.loads(self._reader.recv_bytes())
    if self._rlock: self._rlock.release()
    return ret
  def put(self, obj):
    if self._wlock: self._wlock.acquire()
    self._writer.send_bytes(pickle.dumps(obj))
    if self._wlock: self._wlock.release()
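
# NOTE: MyQueue is a lighter Pipe-based alternative to multiprocessing.Queue; it is currently
# unused except for the commented-out line in batch_load_resnet below.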

def shuffled_indices(n, seed=None):
  rng = random.Random(seed)
  indices = {}
  for i in range(n-1, -1, -1):
    j = rng.randint(0, i)
    if i not in indices: indices[i] = i
    if j not in indices: indices[j] = j
    indices[i], indices[j] = indices[j], indices[i]
    yield indices[i]
    del indices[i]
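
# shuffled_indices is a lazy Fisher-Yates shuffle: it yields a uniform random permutation of
# range(n) one index at a time, keeping only the pending swap entries in a dict instead of
# materializing the full index list up front. illustrative sanity check (not part of the pipeline):
#   assert sorted(shuffled_indices(10, seed=42)) == list(range(10))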

def loader_process(q_in, q_out, X:Tensor, seed):
  import signal
  signal.signal(signal.SIGINT, lambda _, __: exit(0))

  from extra.datasets.imagenet import center_crop, preprocess_train

  with Context(DEBUG=0):
    while (_recv := q_in.get()) is not None:
      idx, fn, val = _recv
      img = Image.open(fn)
      img = img.convert('RGB') if img.mode != "RGB" else img

      if val:
        # eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
        # sudo apt-get install libjpeg-dev
        # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
        img = center_crop(img)
        img = np.array(img)
      else:
        # reseed rng for determinism
        if seed is not None:
          np.random.seed(seed * 2 ** 20 + idx)
          random.seed(seed * 2 ** 20 + idx)
        img = preprocess_train(img)

      # broken out
      #img_tensor = Tensor(img.tobytes(), device='CPU')
      #storage_tensor = X[idx].contiguous().realize().lazydata.realized
      #storage_tensor._copyin(img_tensor.numpy())

      # faster
      X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()

      # ideal
      #X[idx].assign(img.tobytes())   # NOTE: this is slow!

      q_out.put(idx)
  q_out.put(None)
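
# worker protocol: each loader_process pulls (idx, fn, val) from q_in, decodes and augments the
# image with the extra.datasets.imagenet helpers, writes the raw pixel bytes straight into row
# idx of the shared-memory-backed tensor X, and reports idx on q_out. a None on q_in shuts the
# worker down, and it echoes None on q_out so the parent can drain the queue at teardown.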

def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None):
  from extra.datasets.imagenet import get_train_files, get_val_files
  files = get_val_files() if val else get_train_files()
  from extra.datasets.imagenet import get_imagenet_categories
  cir = get_imagenet_categories()

  BATCH_COUNT = min(32, len(files) // batch_size)

  gen = shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
  def enqueue_batch(num):
    for idx in range(num*batch_size, (num+1)*batch_size):
      fn = files[next(gen)]
      q_in.put((idx, fn, val))
      Y[idx] = cir[fn.split("/")[-2]]

  shutdown = False
  class Cookie:
    def __init__(self, num): self.num = num
    def __del__(self):
      if not shutdown:
        try: enqueue_batch(self.num)
        except StopIteration: pass

  gotten = [0]*BATCH_COUNT
  def receive_batch():
    while 1:
      num = q_out.get()//batch_size
      gotten[num] += 1
      if gotten[num] == batch_size: break
    gotten[num] = 0
    return X[num*batch_size:(num+1)*batch_size], Y[num*batch_size:(num+1)*batch_size], Cookie(num)

  #q_in, q_out = MyQueue(multiple_writers=False), MyQueue(multiple_readers=False)
  q_in, q_out = Queue(), Queue()

  sz = (batch_size*BATCH_COUNT, 224, 224, 3)
  if os.path.exists("/dev/shm/resnet_X"): os.unlink("/dev/shm/resnet_X")
  shm = shared_memory.SharedMemory(name="resnet_X", create=True, size=prod(sz))
  procs = []

  try:
    # disk:shm is slower
    #X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}")
    X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X")
    Y = [None] * (batch_size*BATCH_COUNT)

    for _ in range(cpu_count()):
      p = Process(target=loader_process, args=(q_in, q_out, X, seed))
      p.daemon = True
      p.start()
      procs.append(p)

    for bn in range(BATCH_COUNT): enqueue_batch(bn)

    # NOTE: this is batch aligned, last ones are ignored
    for _ in range(0, len(files)//batch_size): yield receive_batch()
  finally:
    shutdown = True
    # empty queues
    for _ in procs: q_in.put(None)
    q_in.close()
    for _ in procs:
      while q_out.get() is not None: pass
    q_out.close()
    # shutdown processes
    for p in procs: p.join()
    shm.close()
    shm.unlink()
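
# batch lifecycle: X holds BATCH_COUNT slots of batch_size images in /dev/shm; receive_batch
# waits until every index of a slot has been reported done, then yields views of X and Y plus a
# Cookie for that slot. dropping the Cookie re-enqueues the next batch into the freed slot, so
# the caller controls backpressure by how long it keeps each batch alive.
# a minimal consumption sketch (illustrative only, not the actual training loop):
#   for x, y, cookie in batch_load_resnet(batch_size=64, shuffle=True, seed=42):
#     ...  # copy x (uint8 disk tensor view) and y (list of int labels) to the device,
#          # then let `cookie` go out of scope to refill this slot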

if __name__ == "__main__":
  from extra.datasets.imagenet import get_train_files, get_val_files
  VAL = getenv("VAL", 1)
  files = get_val_files() if VAL else get_train_files()
  with tqdm(total=len(files)) as pbar:
    for x,y,c in batch_load_resnet(val=VAL):
      pbar.update(x.shape[0])
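
# usage note (assumed invocation): running this file directly with VAL=1 (the default) streams
# the ImageNet eval set, VAL=0 the train set; the tqdm bar above counts images loaded.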