diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 48bbf5033a..d88d72f920 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -348,6 +348,7 @@ def train_retinanet(): def train_unet3d(): from examples.mlperf.losses import dice_ce_loss from examples.mlperf.metrics import dice_score + from examples.mlperf.dataloader import batch_load_unet3d from extra.models.unet3d import UNet3D from extra.datasets.kits19 import iterate, get_train_files, get_val_files, sliding_window_inference, preprocess_dataset, BASEDIR from tinygrad import Device, Tensor, GlobalCounters, Context @@ -423,6 +424,10 @@ def train_unet3d(): print(f"saving checkpoint to {fn}") safe_save(state_dict, fn) + def data_get(it): + x, y, cookie = next(it) + return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie + @TinyJit @Tensor.train() def train_step(model, x, y): @@ -467,33 +472,49 @@ def train_unet3d(): if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0: lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS) + train_dataloader = batch_load_unet3d(PREPROCESSED_DIR / "train", batch_size=BS, val=False, shuffle=True, seed=SEED) + it = iter(tqdm(train_dataloader, total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK)) + i, proc = 0, data_get(it) + + prev_cookies = [] st = time.perf_counter() - for i, (x, y) in enumerate(tqdm(iterate(get_train_files(), preprocessed_dir=PREPROCESSED_DIR, val=False, shuffle=True, bs=BS), total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK), start=1): - dt = time.perf_counter() - + while proc is not None: GlobalCounters.reset() - x, y = Tensor(x).realize().shard(GPUS, axis=0), Tensor(y, requires_grad=False).shard(GPUS, axis=0) + loss, proc = train_step(model, proc[0], proc[1]), proc[2] - loss = train_step(model, x, y) pt = time.perf_counter() + if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued + try: + next_proc = data_get(it) + except StopIteration: + next_proc = None + + dt = time.perf_counter() + + device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}" loss = loss.numpy().item() + cl = time.perf_counter() if BENCHMARK: step_times.append(cl - st) tqdm.write( - f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - st) * 1000.0:6.2f} ms fetch data, " - f"{loss:5.3f} loss, {optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS" + f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, " + f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optim.lr.numpy()[0]:.6f} LR, " + f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS" ) if WANDB: - wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - st, - "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH}) + wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt, + "train/cl_time": cl - dt, "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH}) st = cl + prev_cookies.append(proc) + proc, next_proc = next_proc, None # return old cookie + i += 1 if i == BENCHMARK: median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds