From 199f7c43424e05be7ff901d3dcff485920ba0fd6 Mon Sep 17 00:00:00 2001 From: David Hou Date: Wed, 13 Mar 2024 21:53:41 -0700 Subject: [PATCH] MLPerf Resnet (cleaned up) (#3573) * this is a lot of stuff TEST_TRAIN env for less data don't diskcache get_train_files debug message no lr_scaler for fp32 comment, typo type stuff don't destructure proc make batchnorm parameters float make batchnorm parameters float resnet18, checkpointing hack up checkpointing to keep the names in there oops wandb_resume lower lr eval/ckpt use e+1 lars report top_1_acc some wandb stuff split fw and bw steps to save memory oops save model when reach target formatting make sgd hparams consistent just always write the cats tag... pass X and Y into backward_step to trigger input replace shuffle eval set to fix batchnorm eval dataset is sorted by class, so the means and variances are all wrong small cleanup hack restore only one copy of each tensor do bufs from lin after cache check (lru should handle it fine) record epoch in wandb more digits for topk in eval more env vars small cleanup cleanup hack tricks cleanup hack tricks don't save ckpt for testeval cleanup diskcache train file glob clean up a little device_str SCE into tensor small small log_softmax out of resnet.py oops hack :( comments HeNormal, track gradient norm oops log SYNCBN to wandb real truncnorm less samples for truncated normal custom init for Linear log layer stats small Revert "small" This reverts commit 988f4c1cf35ca4be6c31facafccdd1e177469f2f. Revert "log layer stats" This reverts commit 9d9822458524c514939adeee34b88356cd191cb0. rename BNSYNC to SYNCBN to be consistent with cifar optional TRACK_NORMS fix label smoothing :/ lars skip list only weight decay if not in skip list comment default 0 TRACK_NORMS don't allocate beam scratch buffers if in cache clean up data pipeline, unsplit train/test, put back a hack remove print run test_indexing on remu (#3404) * emulated ops_hip infra * add int4 * include test_indexing in remu * Revert "Merge branch 'remu-dev-mac'" This reverts commit 6870457e57dc5fa70169189fd33b24dbbee99c40, reversing changes made to 3c4c8c9e16d87b291d05e1cab558124cc339ac46. fix bad seeding UnsyncBatchNorm2d but with synced trainable weights label downsample batchnorm in Bottleneck :/ :/ i mean... it runs... its hits the acc... its fast... new unsyncbatchnorm for resnet small fix don't do assign buffer reuse for axis change * remove changes * remove changes * move LARS out of tinygrad/ * rand_truncn rename * whitespace * stray whitespace * no more gnorms * delete some dataloading stuff * remove comment * clean up train script * small comments * move checkpointing stuff to mlperf helpers * if WANDB * small comments * remove whitespace change * new unsynced bn * clean up prints / loop vars * whitespace * undo nn changes * clean up loops * rearrange getenvs * cpu_count() * PolynomialLR whitespace * move he_normal out * cap warmup in polylr * rearrange wandb log * realize both x and y in data_get * use double quotes * combine prints in ckpts resume * take UBN from cifar * running_var * whitespace * whitespace * typo * if instead of ternary for resnet downsample * clean up dataloader cleanup a little? * separate rng for shuffle * clean up imports in model_train * clean up imports * don't realize copyin in data_get * remove TESTEVAL (train dataloader didn't get freed every loop) * adjust wandb_config entries a little * clean up wandb config dict * reduce lines * whitespace * shorter lines * put shm unlink back, but it doesn't seem to do anything * don't pass seed per task * monkeypatch batchnorm * the reseed was wrong * add epoch number to desc * don't unsyncedbatchnorm is syncbn=1 * put back downsample name * eval every epoch * Revert "the reseed was wrong" This reverts commit 3440a07dff3f40e8a8d156ca3f1938558a59249f. * cast lr in onecycle * support fp16 * cut off kernel if expand after reduce * test polynomial lr * move polynomiallr to examples/mlperf * working PolynomialDecayWithWarmup + tests....... add lars_util.py, oops * keep lars_util.py as intact as possible, simplify our interface * no more half * polylr and lars were merged * undo search change * override Linear init * remove half stuff from model_train * update scheduler init with new args * don't divide by input mean * mistake in resnet.py * restore whitespace in resnet.py * add test_data_parallel_resnet_train_step * move initializers out of resnet.py * unused imports * log_softmax to model output in test to fix precision flakiness * log_softmax to model output in test to fix precision flakiness * oops, don't realize here * is None * realize initializations in order for determinism * BENCHMARK flag for number of steps * add resnet to bechmark.yml * return instead of break * missing return * cpu_count, rearrange benchmark.yml * unused variable * disable tqdm if BENCHMARK * getenv WARMUP_EPOCHS * unlink disktensor shm file if exists * terminate instead of join * properly shut down queues * use hip in benchmark for now --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- .github/workflows/benchmark.yml | 6 + examples/mlperf/dataloader.py | 57 ++++++--- examples/mlperf/helpers.py | 28 +++++ examples/mlperf/initializers.py | 27 +++++ examples/mlperf/model_eval.py | 2 +- examples/mlperf/model_train.py | 209 +++++++++++++++++++++++++++++++- extra/datasets/imagenet.py | 85 ++++++++----- extra/models/resnet.py | 42 ++++--- test/test_multitensor.py | 34 +++++- 9 files changed, 413 insertions(+), 77 deletions(-) create mode 100644 examples/mlperf/initializers.py diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 1a3d413c8f..41604467fc 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -177,6 +177,10 @@ jobs: run: time HSA=1 HALF=1 STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt - name: Run MLPerf resnet eval on training data run: time HSA=1 MODEL=resnet python3 examples/mlperf/model_eval.py + - name: Run 10 MLPerf ResNet50 training steps (1 gpu) + run: HIP=1 BENCHMARK=10 BS=104 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt + - name: Run 10 MLPerf ResNet50 training steps (6 gpu) + run: HIP=1 BENCHMARK=10 BS=624 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt - uses: actions/upload-artifact@v4 with: name: Speed (AMD) @@ -187,6 +191,8 @@ jobs: train_cifar_half.txt train_cifar_wino.txt train_cifar_one_gpu.txt + train_resnet.txt + train_resnet_one_gpu.txt train_cifar_six_gpu.txt llama_unjitted.txt llama_jitted.txt diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index d8016a02b4..f2830e855b 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -5,7 +5,7 @@ from tqdm import tqdm import pickle from tinygrad import dtypes, Tensor from tinygrad.helpers import getenv, prod, Timing, Context -from multiprocessing import Queue, Process, shared_memory, connection, Lock +from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count class MyQueue: def __init__(self, multiple_readers=True, multiple_writers=True): @@ -22,33 +22,41 @@ class MyQueue: self._writer.send_bytes(pickle.dumps(obj)) if self._wlock: self._wlock.release() -def shuffled_indices(n): +def shuffled_indices(n, seed=None): + rng = random.Random(seed) indices = {} for i in range(n-1, -1, -1): - j = random.randint(0, i) + j = rng.randint(0, i) if i not in indices: indices[i] = i if j not in indices: indices[j] = j indices[i], indices[j] = indices[j], indices[i] yield indices[i] del indices[i] -def loader_process(q_in, q_out, X:Tensor): +def loader_process(q_in, q_out, X:Tensor, seed): import signal signal.signal(signal.SIGINT, lambda _, __: exit(0)) + from extra.datasets.imagenet import center_crop, preprocess_train + with Context(DEBUG=0): while (_recv := q_in.get()) is not None: - idx, fn = _recv + idx, fn, val = _recv img = Image.open(fn) img = img.convert('RGB') if img.mode != "RGB" else img - # eval: 76.08%, load in 0m7.366s (0m5.301s with simd) - # sudo apt-get install libjpeg-dev - # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd - rescale = min(img.size) / 256 - crop_left = (img.width - 224*rescale) / 2.0 - crop_top = (img.height - 224*rescale) / 2.0 - img = img.resize((224, 224), Image.BILINEAR, box=(crop_left, crop_top, crop_left+224*rescale, crop_top+224*rescale)) + if val: + # eval: 76.08%, load in 0m7.366s (0m5.301s with simd) + # sudo apt-get install libjpeg-dev + # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd + img = center_crop(img) + img = np.array(img) + else: + # reseed rng for determinism + if seed is not None: + np.random.seed(seed * 2 ** 20 + idx) + random.seed(seed * 2 ** 20 + idx) + img = preprocess_train(img) # broken out #img_tensor = Tensor(img.tobytes(), device='CPU') @@ -61,26 +69,29 @@ def loader_process(q_in, q_out, X:Tensor): # ideal #X[idx].assign(img.tobytes()) # NOTE: this is slow! q_out.put(idx) + q_out.put(None) -def batch_load_resnet(batch_size=64, val=False, shuffle=True): +def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None): from extra.datasets.imagenet import get_train_files, get_val_files files = get_val_files() if val else get_train_files() from extra.datasets.imagenet import get_imagenet_categories cir = get_imagenet_categories() BATCH_COUNT = min(32, len(files) // batch_size) - gen = shuffled_indices(len(files)) if shuffle else iter(range(len(files))) + gen = shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files))) def enqueue_batch(num): for idx in range(num*batch_size, (num+1)*batch_size): fn = files[next(gen)] - q_in.put((idx, fn)) + q_in.put((idx, fn, val)) Y[idx] = cir[fn.split("/")[-2]] + shutdown = False class Cookie: def __init__(self, num): self.num = num def __del__(self): - try: enqueue_batch(self.num) - except StopIteration: pass + if not shutdown: + try: enqueue_batch(self.num) + except StopIteration: pass gotten = [0]*BATCH_COUNT def receive_batch(): @@ -105,8 +116,8 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True): X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X") Y = [None] * (batch_size*BATCH_COUNT) - for _ in range(64): - p = Process(target=loader_process, args=(q_in, q_out, X)) + for _ in range(cpu_count()): + p = Process(target=loader_process, args=(q_in, q_out, X, seed)) p.daemon = True p.start() procs.append(p) @@ -116,8 +127,14 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True): # NOTE: this is batch aligned, last ones are ignored for _ in range(0, len(files)//batch_size): yield receive_batch() finally: - # shutdown processes + shutdown = True + # empty queues for _ in procs: q_in.put(None) + q_in.close() + for _ in procs: + while q_out.get() is not None: pass + q_out.close() + # shutdown processes for p in procs: p.join() shm.close() shm.unlink() diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py index fb773b7e61..b8fe3e0e97 100644 --- a/examples/mlperf/helpers.py +++ b/examples/mlperf/helpers.py @@ -2,6 +2,34 @@ from collections import OrderedDict import unicodedata import numpy as np from scipy import signal +from tinygrad.nn import state + +# +# checkpointing utils +# + +def invert_dict(d): return {v: k for k, v in reversed(d.items())} +def dedup_dict(d): return invert_dict(invert_dict(d)) +# store each tensor into the first key it appears in +def get_training_state(model, optimizer, scheduler): + # hack: let get_state_dict walk the tree starting with model, so that the checkpoint keys are + # readable and can be loaded as a model for eval + train_state = {'model': model, 'optimizer': optimizer, 'scheduler': scheduler} + return dedup_dict(state.get_state_dict(train_state)) +def load_training_state(model, optimizer, scheduler, state_dict): + # use fresh model to restore duplicate keys + train_state = {'model': model, 'optimizer': optimizer, 'scheduler': scheduler} + big_dict = state.get_state_dict(train_state) + # hack: put back the dupes + dupe_names = {} + for k, v in big_dict.items(): + if v not in dupe_names: + dupe_names[v] = k + assert k in state_dict + state_dict[k] = state_dict[dupe_names[v]] + # scheduler contains optimizer and all params, load each weight only once + scheduler_state = {'scheduler': scheduler} + state.load_state_dict(scheduler_state, state_dict) def gaussian_kernel(n, std): gaussian_1d = signal.gaussian(n, std) diff --git a/examples/mlperf/initializers.py b/examples/mlperf/initializers.py new file mode 100644 index 0000000000..c98cad43e0 --- /dev/null +++ b/examples/mlperf/initializers.py @@ -0,0 +1,27 @@ +import math + +from tinygrad import Tensor, nn +from tinygrad.helpers import prod, argfix + +# rejection sampling truncated randn +def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor: + CNT=8 + x = Tensor.randn(*(*shape, CNT), dtype=dtype, **kwargs) + ctr = Tensor.arange(CNT).reshape((1,) * len(x.shape[:-1]) + (CNT,)).expand(x.shape) + take = (x.abs() <= truncstds).where(ctr, CNT).min(axis=-1, keepdim=True) # set to 0 if no good samples + return (ctr == take).where(x, 0).sum(axis=-1) + +# https://github.com/keras-team/keras/blob/v2.15.0/keras/initializers/initializers.py#L1026-L1065 +def he_normal(*shape, a: float = 0.00, **kwargs) -> Tensor: + std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) / 0.87962566103423978 + return std * rand_truncn(*shape, **kwargs) + +class Conv2dHeNormal(nn.Conv2d): + def initialize_weight(self, out_channels, in_channels, groups): + return he_normal(out_channels, in_channels//groups, *self.kernel_size, a=0.0) + +class Linear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features, bias=bias) + self.weight = Tensor.normal((out_features, in_features), mean=0.0, std=0.01) + if bias: self.bias = Tensor.zeros(out_features) diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py index 221a70b27b..7dadc9ef36 100644 --- a/examples/mlperf/model_eval.py +++ b/examples/mlperf/model_eval.py @@ -30,7 +30,7 @@ def eval_resnet(): x = x.permute([0,3,1,2]).cast(dtypes.float32) / 255.0 x -= self.input_mean x /= self.input_std - return self.mdl(x).argmax(axis=1).realize() + return self.mdl(x).log_softmax().argmax(axis=1).realize() mdl = TinyJit(ResnetRunner(GPUS)) tlog("loaded models") diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 4ad8742366..38455e675d 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1,9 +1,210 @@ -from tinygrad.tensor import Tensor -from tinygrad.helpers import getenv +import functools +import os +import time +from tqdm import tqdm + +from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes +from tinygrad.helpers import getenv, BEAM, WINO +from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save + +from examples.mlperf.helpers import get_training_state, load_training_state def train_resnet(): - # TODO: Resnet50-v1.5 - pass + from extra.models import resnet + from examples.mlperf.dataloader import batch_load_resnet + from extra.datasets.imagenet import get_train_files, get_val_files + from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup + from examples.mlperf.initializers import Conv2dHeNormal, Linear + from examples.hlb_cifar10 import UnsyncedBatchNorm + + config = {} + seed = config["seed"] = getenv("SEED", 42) + Tensor.manual_seed(seed) # seed for weight initialization + + GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))] + print(f"Training on {GPUS}") + for x in GPUS: Device[x] + + # ** model definition and initializers ** + num_classes = 1000 + resnet.Conv2d = Conv2dHeNormal + resnet.Linear = Linear + if not getenv("SYNCBN"): resnet.BatchNorm = functools.partial(UnsyncedBatchNorm, num_devices=len(GPUS)) + model = resnet.ResNet50(num_classes) + + # shard weights and initialize in order + for k, x in get_state_dict(model).items(): + if not getenv("SYNCBN") and ("running_mean" in k or "running_var" in k): + x.realize().shard_(GPUS, axis=0) + else: + x.realize().to_(GPUS) + parameters = get_parameters(model) + + # ** hyperparameters ** + epochs = config["epochs"] = getenv("EPOCHS", 45) + BS = config["BS"] = getenv("BS", 104 * len(GPUS)) # fp32 GPUS<=6 7900xtx can fit BS=112 + EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS) + base_lr = config["base_lr"] = getenv("LR", 8.4 * (BS/2048)) + lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 5) + decay = config["decay"] = getenv("DECAY", 2e-4) + + target, achieved = getenv("TARGET", 0.759), False + eval_start_epoch = getenv("EVAL_START_EPOCH", 0) + eval_epochs = getenv("EVAL_EPOCHS", 1) + + steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS) + steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS) + + config["BEAM"] = BEAM.value + config["WINO"] = WINO.value + config["SYNCBN"] = getenv("SYNCBN") + + # ** Optimizer ** + from examples.mlperf.optimizers import LARS + skip_list = {v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k} + optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay, skip_list=skip_list) + + # ** LR scheduler ** + scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4, + train_steps=epochs * steps_in_train_epoch, + warmup=lr_warmup_epochs * steps_in_train_epoch) + print(f"training with batch size {BS} for {epochs} epochs") + + # ** resume from checkpointing ** + start_epoch = 0 + if ckpt:=getenv("RESUME", ""): + load_training_state(model, optimizer, scheduler, safe_load(ckpt)) + start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch) + print(f"resuming from {ckpt} at epoch {start_epoch}") + + # ** init wandb ** + WANDB = getenv("WANDB") + if WANDB: + import wandb + wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {} + wandb.init(config=config, **wandb_args) + + # ** jitted steps ** + input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1) + # mlperf reference resnet does not divide by input_std for some reason + # input_std = Tensor([0.229, 0.224, 0.225], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1) + def normalize(x): return x.permute([0, 3, 1, 2]) - input_mean + @TinyJit + def train_step(X, Y): + optimizer.zero_grad() + X = normalize(X) + out = model.forward(X) + loss = out.sparse_categorical_crossentropy(Y, label_smoothing=0.1) + top_1 = (out.argmax(-1) == Y).sum() + loss.backward() + optimizer.step() + scheduler.step() + return loss.realize(), top_1.realize() + @TinyJit + def eval_step(X, Y): + X = normalize(X) + out = model.forward(X) + loss = out.sparse_categorical_crossentropy(Y, label_smoothing=0.1) + top_1 = (out.argmax(-1) == Y).sum() + return loss.realize(), top_1.realize() + def data_get(it): + x, y, cookie = next(it) + return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), cookie + + # ** epoch loop ** + for e in range(start_epoch, epochs): + # ** train loop ** + Tensor.training = True + it = iter(tqdm(batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e), + total=steps_in_train_epoch, desc=f"epoch {e}", disable=getenv("BENCHMARK"))) + i, proc = 0, data_get(it) + st = time.perf_counter() + while proc is not None: + GlobalCounters.reset() + (loss, top_1_acc), proc = train_step(proc[0], proc[1]), proc[2] + + pt = time.perf_counter() + + try: + next_proc = data_get(it) + except StopIteration: + next_proc = None + + dt = time.perf_counter() + + device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}" + loss, top_1_acc = loss.numpy().item(), top_1_acc.numpy().item() / BS + + cl = time.perf_counter() + + tqdm.write( + f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, " + f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {top_1_acc:3.2f} acc, {optimizer.lr.numpy()[0]:.6f} LR, " + f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS") + if WANDB: + wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/top_1_acc": top_1_acc, "train/step_time": cl - st, + "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt, + "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch}) + + st = cl + proc, next_proc = next_proc, None # return old cookie + i += 1 + + if i == getenv("BENCHMARK"): return + + # ** eval loop ** + if (e + 1 - eval_start_epoch) % eval_epochs == 0: + train_step.reset() # free the train step memory :( + eval_loss = [] + eval_times = [] + eval_top_1_acc = [] + Tensor.training = False + + it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False), total=steps_in_val_epoch)) + proc = data_get(it) + while proc is not None: + GlobalCounters.reset() + st = time.time() + + (loss, top_1_acc), proc = eval_step(proc[0], proc[1]), proc[2] # drop inputs, keep cookie + + try: + next_proc = data_get(it) + except StopIteration: + next_proc = None + + loss, top_1_acc = loss.numpy().item(), top_1_acc.numpy().item() / EVAL_BS + eval_loss.append(loss) + eval_top_1_acc.append(top_1_acc) + proc, next_proc = next_proc, None # return old cookie + + et = time.time() + eval_times.append(et - st) + + eval_step.reset() + total_loss = sum(eval_loss) / len(eval_loss) + total_top_1 = sum(eval_top_1_acc) / len(eval_top_1_acc) + total_fw_time = sum(eval_times) / len(eval_times) + tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}") + if WANDB: + wandb.log({"eval/loss": total_loss, "eval/top_1_acc": total_top_1, "eval/forward_time": total_fw_time, "epoch": e + 1}) + + # save model if achieved target + if not achieved and total_top_1 >= target: + fn = f"./ckpts/resnet50.safe" + safe_save(get_state_dict(model), fn) + print(f" *** Model saved to {fn} ***") + achieved = True + + # checkpoint every time we eval + if getenv("CKPT"): + if not os.path.exists("./ckpts"): os.mkdir("./ckpts") + if WANDB and wandb.run is not None: + fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe" + else: + fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe" + print(f"saving ckpt to {fn}") + safe_save(get_training_state(model, optimizer, scheduler), fn) def train_retinanet(): # TODO: Retinanet diff --git a/extra/datasets/imagenet.py b/extra/datasets/imagenet.py index 1770bc12dd..33bbe34100 100644 --- a/extra/datasets/imagenet.py +++ b/extra/datasets/imagenet.py @@ -1,9 +1,9 @@ # for imagenet download prepare.sh and run it -import glob, random, json +import glob, random, json, math import numpy as np from PIL import Image import functools, pathlib -from tinygrad.helpers import DEBUG, diskcache +from tinygrad.helpers import diskcache, getenv BASEDIR = pathlib.Path(__file__).parent / "imagenet" @@ -18,37 +18,58 @@ def get_train_files(): return glob.glob(str(BASEDIR / "train/*/*")) @functools.lru_cache(None) def get_val_files(): return glob.glob(str(BASEDIR / "val/*/*")) -def image_load(fn): - import torchvision.transforms.functional as F - img = Image.open(fn).convert('RGB') - img = F.resize(img, 256, Image.BILINEAR) - img = F.center_crop(img, 224) - ret = np.array(img) - return ret +def image_resize(img, size, interpolation): + w, h = img.size + w_new = int((w / h) * size) if w > h else size + h_new = int((h / w) * size) if h > w else size + return img.resize([w_new, h_new], interpolation) -def iterate(bs=32, val=True, shuffle=True): - cir = get_imagenet_categories() - files = get_val_files() if val else get_train_files() - order = list(range(0, len(files))) - if DEBUG >= 1: print(f"imagenet size {len(order)}") - if shuffle: random.shuffle(order) - from multiprocessing import Pool - p = Pool(16) - for i in range(0, len(files), bs): - X = p.map(image_load, [files[i] for i in order[i:i+bs]]) - Y = [cir[files[i].split("/")[-2]] for i in order[i:i+bs]] - yield (np.array(X), np.array(Y)) +def rand_flip(img): + if random.random() < 0.5: + img = np.flip(img, axis=1).copy() + return img -def fetch_batch(bs, val=False): - cir = get_imagenet_categories() - files = get_val_files() if val else get_train_files() - samp = np.random.randint(0, len(files), size=(bs)) - files = [files[i] for i in samp] - X = [image_load(x) for x in files] - Y = [cir[x.split("/")[0]] for x in files] - return np.array(X), np.array(Y) +def center_crop(img): + rescale = min(img.size) / 256 + crop_left = (img.width - 224 * rescale) / 2.0 + crop_top = (img.height - 224 * rescale) / 2.0 + img = img.resize((224, 224), Image.BILINEAR, box=(crop_left, crop_top, crop_left + 224 * rescale, crop_top + 224 * rescale)) + return img -if __name__ == "__main__": - X,Y = fetch_batch(64) - print(X.shape, Y) +# we don't use supplied imagenet bounding boxes, so scale min is just min_object_covered +# https://github.com/tensorflow/tensorflow/blob/e193d8ea7776ef5c6f5d769b6fb9c070213e737a/tensorflow/core/kernels/image/sample_distorted_bounding_box_op.cc +def random_resized_crop(img, size, scale=(0.10, 1.0), ratio=(3/4, 4/3)): + w, h = img.size + area = w * h + # Crop + random_solution_found = False + for _ in range(10): + aspect_ratio = random.uniform(ratio[0], ratio[1]) + max_scale = min(min(w * aspect_ratio / h, h / aspect_ratio / w), scale[1]) + target_area = area * random.uniform(scale[0], max_scale) + + w_new = int(round(math.sqrt(target_area * aspect_ratio))) + h_new = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w_new <= w and 0 < h_new <= h: + crop_left = random.randint(0, w - w_new + 1) + crop_top = random.randint(0, h - h_new + 1) + + img = img.crop((crop_left, crop_top, crop_left + w_new, crop_top + h_new)) + random_solution_found = True + break + + if not random_solution_found: + # Center crop + img = center_crop(img) + else: + # Resize + img = img.resize([size, size], Image.BILINEAR) + + return img + +def preprocess_train(img): + img = random_resized_crop(img, 224) + img = rand_flip(np.array(img)) + return img diff --git a/extra/models/resnet.py b/extra/models/resnet.py index 517f1ec9e1..8e3388b40c 100644 --- a/extra/models/resnet.py +++ b/extra/models/resnet.py @@ -3,20 +3,26 @@ from tinygrad.tensor import Tensor from tinygrad.nn.state import torch_load from tinygrad.helpers import fetch, get_child +# allow monkeypatching in layer implementations +BatchNorm = nn.BatchNorm2d +Conv2d = nn.Conv2d +Linear = nn.Linear + + class BasicBlock: expansion = 1 def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64): assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64" - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) + self.conv1 = Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = BatchNorm(planes) + self.conv2 = Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False) + self.bn2 = BatchNorm(planes) self.downsample = [] if stride != 1 or in_planes != self.expansion*planes: self.downsample = [ - nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) + Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + BatchNorm(self.expansion*planes) ] def __call__(self, x): @@ -34,17 +40,17 @@ class Bottleneck: def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64): width = int(planes * (base_width / 64.0)) * groups # NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1 - self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False) - self.bn1 = nn.BatchNorm2d(width) - self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False) - self.bn2 = nn.BatchNorm2d(width) - self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(self.expansion*planes) + self.conv1 = Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False) + self.bn1 = BatchNorm(width) + self.conv2 = Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False) + self.bn2 = BatchNorm(width) + self.conv3 = Conv2d(width, self.expansion*planes, kernel_size=1, bias=False) + self.bn3 = BatchNorm(self.expansion*planes) self.downsample = [] if stride != 1 or in_planes != self.expansion*planes: self.downsample = [ - nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) + Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + BatchNorm(self.expansion*planes) ] def __call__(self, x): @@ -78,13 +84,13 @@ class ResNet: self.groups = groups self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3) - self.bn1 = nn.BatchNorm2d(64) + self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3) + self.bn1 = BatchNorm(64) self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1) self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1) self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1) self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1) - self.fc = nn.Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None + self.fc = Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1): strides = [stride] + [1] * (num_blocks-1) @@ -112,7 +118,7 @@ class ResNet: if is_feature_only: features.append(out) if not is_feature_only: out = out.mean([2,3]) - out = self.fc(out).log_softmax() + out = self.fc(out) return out return features diff --git a/test/test_multitensor.py b/test/test_multitensor.py index e82256af26..754eba92ee 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -204,15 +204,45 @@ class TestMultiTensor(unittest.TestCase): fake_image_sharded = fake_image.shard((d0, d1), axis=0) m = ResNet18() m.load_from_pretrained() - real_output = m(fake_image).numpy() + real_output = m(fake_image).log_softmax().numpy() for p in get_parameters(m): p.shard_((d0, d1)).realize() GlobalCounters.reset() - shard_output = m(fake_image_sharded).realize() + shard_output = m(fake_image_sharded).log_softmax().realize() assert shard_output.lazydata.lbs[0].shape == (1, 1000) assert shard_output.lazydata.lbs[1].shape == (1, 1000) shard_output_np = shard_output.numpy() np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6) + def test_data_parallel_resnet_train_step(self): + import sys, pathlib + sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix()) + from resnet import ResNet18 + from examples.mlperf.optimizers import LARS + + fake_image = Tensor.rand((2, 3, 224, 224)) + fake_image_sharded = fake_image.shard((d0, d1), axis=0) + labels = Tensor.randint(2, low=0, high=1000) + labels_sharded = labels.shard((d0, d1), axis=0) + + m = ResNet18() + optimizer = LARS(get_parameters(m), 0.1) # set requires_grad for all params + + optimizer.zero_grad() + m.load_from_pretrained() + output = m(fake_image).sparse_categorical_crossentropy(labels, label_smoothing=0.1) + output.backward() + grad = m.conv1.weight.grad.numpy() + + for p in get_parameters(m): p.shard_((d0, d1)).realize() + GlobalCounters.reset() + optimizer.zero_grad() + shard_output = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1) + assert shard_output.lazydata.axis is None + shard_output.backward() + shard_grad = m.conv1.weight.grad.numpy() + # sometimes there is zeros in these grads... why? + np.testing.assert_allclose(grad, shard_grad, atol=1e-6, rtol=1e-6) + def test_multi_tensor_jit_param(self): @TinyJit def jf(a, b) -> Tensor: