tinygrad/examples/mlperf/model_train.py
Francis Lata b7ce9a1530 UNet3D MLPerf (#3470)
* add training set transforms

* add DICE cross entropy loss

* convert pred and label to Tensor when calculating DICE score

* cleanups and allow train dataset batching

* fix DICE CE loss calculation

* jitted training step

* clean up DICE CE loss calculation

* initial support for sharding

* Revert "initial support for sharding"

This reverts commit e3670813b8.

* minor updates

* cleanup imports

* add support for sharding

* apply temp patch to try to avoid OOM

* revert cstyle changes

* add gradient acc

* hotfix

* add FP16 support

* add ability to train on smaller image sizes

* add support for saving and loading checkpoints + clean up various modes

* fix issue with using smaller patch size + update W&B logging

* disable LR_WARMUP_EPOCHS

* updates

* minor cleanups

* cleanup

* update order of transformations

* more cleanups

* realize loss

* cleanup

* more cleanup

* some cleanups

* add RAM usage

* minor cleanups

* add support for gradient accumulation

* cleanup imports

* minor updates to not use GA_STEPS

* remove FP16 option since it's available now globally

* update multi-GPU setup

* add timing logs for training loop

* go back to using existing dataloader and add ability to preprocess data to save time

* clean up optimization and re-enable JIT and multi-GPU support for training and evaluation

* free train and eval steps memory

* cleanups and scale batch size based on the number of GPUs

* fix GlobalCounters import

* fix seed

* fix W&B setup

* update batch size default size

* add back metric divergence check

* put back JIT on UNet3d eval

* move dataset preprocessing inside training code

* add test for dice_loss

* add config logging support to W&B and other cleanups

* change how default float is getting retrieved

* remove TinyJit import duplicate

* update config logging to W&B and remove JIT on eval_step

* no need for caching preprocessed data anymore

* fix how evaluation is run and how often

* add support for LR scaling

* fix issue with gaussian being moved to scipy.signal.windows

* remove DICE loss unit test

* fix issue where loss isn't compatible with multiGPU

* add individual BEAM control for train and eval steps

* fix ndimage scipy import

* add BENCHMARK

* cleanups on BENCHMARK + fix on rand_flip augmentation during training

* cleanup train and eval BEAM envs

* add checkpointing support after every eval

* cleanup model_eval

* disable grad during eval

* use new preprocessing dataset mechanism

* remove unused import

* use training and inference_mode contexts

* start eval after benchmarking

* add data fetching time

* cleanup decorators

* more cleanups on training script

* add message during benchmarking mode

* realize when reassigning LR on scheduler and update default number of epochs

* add JIT on eval step

* remove JIT on eval_step

* add train dataloader for unet3d

* move checkpointing to be done after every epoch

* revert removal of JIT on unet3d inference

* save checkpoint if metric is not successful

* Revert "add train dataloader for unet3d"

This reverts commit c166d129df.

* Revert "Revert "add train dataloader for unet3d""

This reverts commit 36366c65d2.

* hotfix: seed was defaulting to a value of 0

* fix SEED value

* remove the usage of context managers for setting BEAM and going from training to inference

* support new stack API for calculating eval loss and metric

* Revert "remove the usage of context managers for setting BEAM and going from training to inference"

This reverts commit 2c0ba8d322.

* check training and test preprocessed folders separately

* clean up imports and log FUSE_CONV_BW

* use train and val preprocessing constants

* add kits19 dataset setup script

* update to use the new test decorator for disabling grad

* update kits19 dataset setup script

* add docs on how to train the model

* set default value for BASEDIR

* add detailed instruction about BASEDIR usage

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-09-10 04:37:28 -04:00

910 lines · 40 KiB · Python

import os, time, math, functools
from pathlib import Path
from tqdm import tqdm
import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state
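# each train_* function below implements one MLPerf training benchmark; the MODEL env var in the
# __main__ block at the bottom of this file selects which ones run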
def train_resnet():
from extra.models import resnet
from examples.mlperf.dataloader import batch_load_resnet
from extra.datasets.imagenet import get_train_files, get_val_files
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
from examples.mlperf.initializers import Conv2dHeNormal, Linear
from examples.hlb_cifar10 import UnsyncedBatchNorm
config = {}
seed = config["seed"] = getenv("SEED", 42)
Tensor.manual_seed(seed) # seed for weight initialization
INITMLPERF = getenv("INITMLPERF")
RUNMLPERF = getenv("RUNMLPERF")
if getenv("LOGMLPERF"):
from mlperf_logging import mllog
import mlperf_logging.mllog.constants as mllog_constants
mllog.config(filename=f"result_{seed}.txt")
mllog.config(root_dir=Path(__file__).parents[3].as_posix()) # truncate to log this. "file": "tinygrad/examples/mlperf/model_train.py"
MLLOGGER = mllog.get_mllogger()
if INITMLPERF:
# common.yaml
MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
# closed_common.yaml
MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.RESNET)
diskcache_clear()
MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
MLLOGGER.start(key=mllog_constants.INIT_START)
if RUNMLPERF:
MLLOGGER.start(key=mllog_constants.RUN_START)
MLLOGGER.event(key=mllog_constants.SEED, value=seed)
else:
MLLOGGER = None
GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
print(f"training on {GPUS}")
for x in GPUS: Device[x]
TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
# ** model definition and initializers **
num_classes = 1000
resnet.Conv2d = Conv2dHeNormal
resnet.Linear = Linear
if not getenv("SYNCBN"): resnet.BatchNorm = functools.partial(UnsyncedBatchNorm, num_devices=len(GPUS))
model = resnet.ResNet50(num_classes)
# shard weights and initialize in order
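# when SYNCBN is off, running_mean/running_var are sharded across GPUS (axis=0) so each device keeps
# its own batchnorm statistics; all other weights are replicated to every device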
for k, x in get_state_dict(model).items():
if not getenv("SYNCBN") and ("running_mean" in k or "running_var" in k):
x.realize().shard_(GPUS, axis=0)
else:
x.realize().to_(GPUS)
parameters = get_parameters(model)
# ** hyperparameters **
epochs = config["epochs"] = getenv("EPOCHS", 37)
BS = config["BS"] = getenv("BS", 104 * len(GPUS)) # fp32 GPUS<=6 7900xtx can fit BS=112
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS)
base_lr = config["base_lr"] = getenv("LR", 7.2 * (BS/1536))
lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 2)
decay = config["decay"] = getenv("DECAY", 2e-4)
loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 128.0 if dtypes.default_float == dtypes.float16 else 1.0)
target, achieved = getenv("TARGET", 0.759), False
eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
eval_freq = getenv("EVAL_FREQ", 1)
steps_in_train_epoch = config["steps_in_train_epoch"] = (round_up(len(get_train_files()), BS) // BS)
steps_in_val_epoch = config["steps_in_val_epoch"] = (round_up(len(get_val_files()), EVAL_BS) // EVAL_BS)
config["DEFAULT_FLOAT"] = dtypes.default_float.name
config["BEAM"] = BEAM.value
config["TRAIN_BEAM"] = TRAIN_BEAM
config["EVAL_BEAM"] = EVAL_BEAM
config["WINO"] = WINO.value
config["SYNCBN"] = getenv("SYNCBN")
# ** Optimizer **
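# batchnorm parameters, biases, and the downsample batchnorm ("downsample.1") are kept out of LARS
# and trained with plain momentum SGD without weight decay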
skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
parameters = [x for x in parameters if x not in set(skip_list)]
optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay)
optimizer_skip = SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True)
optimizer_group = OptimizerGroup(optimizer, optimizer_skip)
# ** LR scheduler **
scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4,
train_steps=epochs * steps_in_train_epoch,
warmup=lr_warmup_epochs * steps_in_train_epoch)
scheduler_skip = PolynomialDecayWithWarmup(optimizer_skip, initial_lr=base_lr, end_lr=1e-4,
train_steps=epochs * steps_in_train_epoch,
warmup=lr_warmup_epochs * steps_in_train_epoch)
scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
print(f"training with batch size {BS} for {epochs} epochs")
# log mlperf hparams
if MLLOGGER:
if RUNMLPERF:
MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=BS)
from extra.datasets.imagenet import get_train_files, get_val_files
MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=len(get_train_files()))
MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=len(get_val_files()))
MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
MLLOGGER.event(key=mllog_constants.OPT_NAME, value="lars")
assert scheduler.initial_lr == scheduler_skip.initial_lr
assert scheduler.end_lr == scheduler_skip.end_lr
assert scheduler.power == scheduler_skip.power
MLLOGGER.event(key=mllog_constants.LARS_OPT_BASE_LEARNING_RATE, value=scheduler.initial_lr)
MLLOGGER.event(key=mllog_constants.LARS_OPT_END_LR, value=scheduler.end_lr)
MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_POLY_POWER, value=scheduler.power)
MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_STEPS, value=epochs)
MLLOGGER.event(key=mllog_constants.LARS_EPSILON, value=0) # does not support epsilon != 0
MLLOGGER.event(key=mllog_constants.LARS_OPT_LEARNING_RATE_WARMUP_EPOCHS, value=lr_warmup_epochs)
MLLOGGER.event(key=mllog_constants.LARS_OPT_MOMENTUM, value=optimizer.momentum)
MLLOGGER.event(key=mllog_constants.LARS_OPT_WEIGHT_DECAY, value=optimizer.wd)
# ** resume from checkpointing **
start_epoch = 0
if ckpt:=getenv("RESUME", ""):
load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch)
print(f"resuming from {ckpt} at epoch {start_epoch}")
# ** init wandb **
WANDB = getenv("WANDB")
if WANDB:
import wandb
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
wandb.init(config=config, **wandb_args)
BENCHMARK = getenv("BENCHMARK")
# ** jitted steps **
input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
# mlperf reference resnet does not divide by input_std for some reason
# input_std = Tensor([0.229, 0.224, 0.225], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)
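# loss scaling: the loss is scaled up before backward and the gradients are scaled back down afterwards
# so fp16 gradients don't underflow (loss_scaler defaults to 1.0 for fp32)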
@TinyJit
def train_step(X, Y):
optimizer_group.zero_grad()
X = normalize(X)
out = model.forward(X)
loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
top_1 = (out.argmax(-1) == Y).sum()
(loss * loss_scaler).backward()
for t in optimizer_group.params: t.grad = t.grad.contiguous() / loss_scaler
optimizer_group.step()
scheduler_group.step()
return loss.realize(), top_1.realize()
@TinyJit
def eval_step(X, Y):
X = normalize(X)
out = model.forward(X)
loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
top_1 = (out.argmax(-1) == Y).sum()
return loss.realize(), top_1.realize()
def fake_data_get(batch_size):
x = Tensor.zeros(batch_size, 224, 224, 3, dtype=dtypes.uchar).contiguous()
y = [0] * batch_size
return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, None
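# batches come back with a "cookie"; holding a reference to the previous cookie until the next step's
# GPU work has been enqueued keeps that batch's buffer alive (see prev_cookies in the loops below)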
def data_get(it):
x, y, cookie = next(it)
return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, cookie
# ** epoch loop **
step_times = []
for e in range(start_epoch, epochs):
# ** train loop **
if MLLOGGER and RUNMLPERF:
MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e+1, metadata=dict(epoch_num=e+1))
Tensor.training = True
BEAM.value = TRAIN_BEAM
if INITMLPERF:
i, proc = 0, fake_data_get(BS)
else:
batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e, pad_first_batch=True)
it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
i, proc = 0, data_get(it)
prev_cookies = []
st = time.perf_counter()
while proc is not None:
GlobalCounters.reset()
(loss, top_1), y, proc = train_step(proc[0], proc[1]), proc[2], proc[3]
pt = time.perf_counter()
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
try:
if INITMLPERF:
next_proc = fake_data_get(BS)
else:
next_proc = data_get(it)
except StopIteration:
next_proc = None
dt = time.perf_counter()
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
loss, top_1 = loss.numpy().item(), top_1.numpy().item()
top_1_acc = top_1 / sum(yi != -1 for yi in y)
cl = time.perf_counter()
if BENCHMARK:
step_times.append(cl - st)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {top_1_acc:3.2f} acc, {optimizer.lr.numpy()[0]:.6f} LR, "
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
if WANDB:
wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/top_1_acc": top_1_acc, "train/step_time": cl - st,
"train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})
st = cl
prev_cookies.append(proc)
proc, next_proc = next_proc, None # return old cookie
i += 1
if i == BENCHMARK:
assert not math.isnan(loss)
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
# if we are doing beam search, run the first eval too
if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
return
if MLLOGGER and RUNMLPERF:
MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=e+1, metadata=dict(epoch_num=e+1))
# ** eval loop **
# always eval for epoch >= 33 to stop the clock as soon as the eval target is hit; the model can converge anywhere in epochs 33-37
if steps_in_val_epoch > 0 and ((e + 1 - eval_start_epoch) % eval_freq == 0 or e + 1 >= 33):
if MLLOGGER and RUNMLPERF:
MLLOGGER.start(key=mllog_constants.EVAL_START, value=e+1, metadata=dict(epoch_num=e+1))
if getenv("RESET_STEP", 1): train_step.reset() # free the train step memory :(
eval_times = []
eval_loss = 0.0
eval_top_1 = 0
eval_num_samples = 0
Tensor.training = False
BEAM.value = EVAL_BEAM
if INITMLPERF:
i, proc = 0, fake_data_get(EVAL_BS)
else:
it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False, pad_first_batch=True), total=steps_in_val_epoch))
i, proc = 0, data_get(it)
prev_cookies = []
while proc is not None:
GlobalCounters.reset()
st = time.time()
(loss, top_1), y, proc = eval_step(proc[0], proc[1]), proc[2], proc[3] # drop inputs, keep cookie
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
try:
if INITMLPERF:
next_proc = fake_data_get(EVAL_BS)
else:
next_proc = data_get(it)
except StopIteration:
next_proc = None
loss, top_1 = loss.numpy().item(), top_1.numpy().item()
num_samples = sum(yi != -1 for yi in y)
eval_loss += loss * num_samples
eval_top_1 += top_1
eval_num_samples += num_samples
prev_cookies.append(proc)
proc, next_proc = next_proc, None
i += 1
if i == BENCHMARK:
# assume INITMLPERF has BENCHMARK set
if MLLOGGER and INITMLPERF:
MLLOGGER.event(key=mllog_constants.INIT_STOP)
return
et = time.time()
eval_times.append(et - st)
if getenv("RESET_STEP", 1): eval_step.reset()
if not BENCHMARK:
assert eval_num_samples == len(get_val_files()), f"eval sample count mismatched. {eval_num_samples=} != {len(get_val_files())}"
total_loss = eval_loss / eval_num_samples
total_top_1 = eval_top_1 / eval_num_samples
total_fw_time = sum(eval_times) / len(eval_times)
tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}")
if WANDB:
wandb.log({"eval/loss": total_loss, "eval/top_1_acc": total_top_1, "eval/forward_time": total_fw_time, "epoch": e + 1})
if MLLOGGER and RUNMLPERF:
MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=total_top_1, metadata=dict(epoch_num=e+1))
MLLOGGER.event(key=mllog_constants.EVAL_STOP, value=e+1, metadata=dict(epoch_num=e+1))
# save model if achieved target
if not achieved and total_top_1 >= target:
# stop once the target is achieved
if MLLOGGER and RUNMLPERF:
MLLOGGER.event(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
fn = f"./ckpts/resnet50_{seed}.safe"
safe_save(get_state_dict(model), fn)
print(f" *** Model saved to {fn} ***")
achieved = True
break
# checkpoint every time we eval
if getenv("CKPT"):
if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
if WANDB and wandb.run is not None:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe"
else:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe"
print(f"saving ckpt to {fn}")
safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
def train_retinanet():
# TODO: Retinanet
pass
def train_unet3d():
"""
Trains the UNet3D model.
Instructions:
1) Run the following script from the root folder of `tinygrad`:
```./examples/mlperf/scripts/setup_kits19_dataset.sh```
Optionally, `BASEDIR` can be set to download and process the dataset at a specific location:
```BASEDIR=<folder_path> ./examples/mlperf/scripts/setup_kits19_dataset.sh```
2) To start training the model, run the following:
```time PYTHONPATH=. WANDB=1 TRAIN_BEAM=3 FUSE_CONV_BW=1 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```
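Optionally, `BENCHMARK=<num_steps>` times the first few training steps, prints an estimated total training time, and exits early instead of training to convergence:
```PYTHONPATH=. BENCHMARK=10 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```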
"""
from examples.mlperf.losses import dice_ce_loss
from examples.mlperf.metrics import dice_score
from examples.mlperf.dataloader import batch_load_unet3d
from extra.models.unet3d import UNet3D
from extra.datasets.kits19 import iterate, get_train_files, get_val_files, sliding_window_inference, preprocess_dataset, TRAIN_PREPROCESSED_DIR, VAL_PREPROCESSED_DIR
from tinygrad import Context
from tinygrad.nn.optim import SGD
from math import ceil
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
for x in GPUS: Device[x]
TARGET_METRIC = 0.908
NUM_EPOCHS = getenv("NUM_EPOCHS", 4000)
BS = getenv("BS", 1 * len(GPUS))
LR = getenv("LR", 2.0 * (BS / 28))
LR_WARMUP_EPOCHS = getenv("LR_WARMUP_EPOCHS", 1000)
LR_WARMUP_INIT_LR = getenv("LR_WARMUP_INIT_LR", 0.0001)
WANDB = getenv("WANDB")
PROJ_NAME = getenv("PROJ_NAME", "tinygrad_unet3d_mlperf")
SEED = getenv("SEED", -1) if getenv("SEED", -1) >= 0 else None
TRAIN_DATASET_SIZE, VAL_DATASET_SIZE = len(get_train_files()), len(get_val_files())
SAMPLES_PER_EPOCH = TRAIN_DATASET_SIZE // BS
START_EVAL_AT = getenv("START_EVAL_AT", ceil(1000 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
EVALUATE_EVERY = getenv("EVALUATE_EVERY", ceil(20 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
TRAIN_BEAM, EVAL_BEAM = getenv("TRAIN_BEAM", BEAM.value), getenv("EVAL_BEAM", BEAM.value)
BENCHMARK = getenv("BENCHMARK")
CKPT = getenv("CKPT")
config = {
"num_epochs": NUM_EPOCHS,
"batch_size": BS,
"learning_rate": LR,
"learning_rate_warmup_epochs": LR_WARMUP_EPOCHS,
"learning_rate_warmup_init": LR_WARMUP_INIT_LR,
"start_eval_at": START_EVAL_AT,
"evaluate_every": EVALUATE_EVERY,
"train_beam": TRAIN_BEAM,
"eval_beam": EVAL_BEAM,
"wino": WINO.value,
"fuse_conv_bw": FUSE_CONV_BW.value,
"gpus": GPUS,
"default_float": dtypes.default_float.name
}
if WANDB:
try:
import wandb
except ImportError:
raise "Need to install wandb to use it"
if SEED is not None:
config["seed"] = SEED
Tensor.manual_seed(SEED)
model = UNet3D()
params = get_parameters(model)
for p in params: p.realize().to_(GPUS)
optim = SGD(params, lr=LR, momentum=0.9, nesterov=True)
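# linear LR warmup: ramp the learning rate from LR_WARMUP_INIT_LR up to LR over the first LR_WARMUP_EPOCHS epochs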
def lr_warm_up(optim, init_lr, lr, current_epoch, warmup_epochs):
scale = current_epoch / warmup_epochs
optim.lr.assign(Tensor([init_lr + (lr - init_lr) * scale], device=GPUS)).realize()
def save_checkpoint(state_dict, fn):
if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
print(f"saving checkpoint to {fn}")
safe_save(state_dict, fn)
def data_get(it):
x, y, cookie = next(it)
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
@TinyJit
@Tensor.train()
def train_step(model, x, y):
optim.zero_grad()
y_hat = model(x)
loss = dice_ce_loss(y_hat, y)
loss.backward()
optim.step()
return loss.realize()
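# evaluation runs sliding-window inference over full volumes and reports the per-class DICE score alongside the DICE-CE loss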
@Tensor.train(mode=False)
@Tensor.test()
def eval_step(model, x, y):
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
y_hat, y = Tensor(y_hat), Tensor(y, requires_grad=False)
loss = dice_ce_loss(y_hat, y)
score = dice_score(y_hat, y)
return loss.realize(), score.realize()
if WANDB: wandb.init(config=config, project=PROJ_NAME)
step_times, start_epoch = [], 1
is_successful, diverged = False, False
start_eval_at, evaluate_every = 1 if BENCHMARK else START_EVAL_AT, 1 if BENCHMARK else EVALUATE_EVERY
next_eval_at = start_eval_at
print(f"Training on {GPUS}")
if BENCHMARK: print("Benchmarking UNet3D")
else: print(f"Start evaluation at epoch {start_eval_at} and every {evaluate_every} epoch(s) afterwards")
if not TRAIN_PREPROCESSED_DIR.exists(): preprocess_dataset(get_train_files(), TRAIN_PREPROCESSED_DIR, False)
if not VAL_PREPROCESSED_DIR.exists(): preprocess_dataset(get_val_files(), VAL_PREPROCESSED_DIR, True)
for epoch in range(1, NUM_EPOCHS + 1):
with Context(BEAM=TRAIN_BEAM):
if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0:
lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS)
train_dataloader = batch_load_unet3d(TRAIN_PREPROCESSED_DIR, batch_size=BS, val=False, shuffle=True, seed=SEED)
it = iter(tqdm(train_dataloader, total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK))
i, proc = 0, data_get(it)
prev_cookies = []
st = time.perf_counter()
while proc is not None:
GlobalCounters.reset()
loss, proc = train_step(model, proc[0], proc[1]), proc[2]
pt = time.perf_counter()
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
try:
next_proc = data_get(it)
except StopIteration:
next_proc = None
dt = time.perf_counter()
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
loss = loss.numpy().item()
cl = time.perf_counter()
if BENCHMARK: step_times.append(cl - st)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optim.lr.numpy()[0]:.6f} LR, "
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
)
if WANDB:
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt,
"train/cl_time": cl - dt, "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})
st = cl
prev_cookies.append(proc)
proc, next_proc = next_proc, None # return old cookie
i += 1
if i == BENCHMARK:
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_minutes = int(median_step_time * SAMPLES_PER_EPOCH * NUM_EPOCHS / 60)
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
if (TRAIN_BEAM or EVAL_BEAM) and epoch == start_epoch: break
return
with Context(BEAM=EVAL_BEAM):
if epoch == next_eval_at:
next_eval_at += evaluate_every
eval_loss = []
scores = []
for x, y in tqdm(iterate(get_val_files(), preprocessed_dir=VAL_PREPROCESSED_DIR), total=VAL_DATASET_SIZE):
eval_loss_value, score = eval_step(model, x, y)
eval_loss.append(eval_loss_value)
scores.append(score)
scores = Tensor.mean(Tensor.stack(*scores, dim=0), axis=0).numpy()
eval_loss = Tensor.mean(Tensor.stack(*eval_loss, dim=0), axis=0).numpy()
l1_dice, l2_dice = scores[0][-2], scores[0][-1]
mean_dice = (l2_dice + l1_dice) / 2
tqdm.write(f"{l1_dice} L1 dice, {l2_dice} L2 dice, {mean_dice:.3f} mean_dice, {eval_loss:5.2f} eval_loss")
if WANDB:
wandb.log({"eval/loss": eval_loss, "eval/mean_dice": mean_dice, "epoch": epoch})
if mean_dice >= TARGET_METRIC:
is_successful = True
save_checkpoint(get_state_dict(model), f"./ckpts/unet3d.safe")
elif mean_dice < 1e-6:
print("Model diverging. Aborting.")
diverged = True
if not is_successful and CKPT:
if WANDB and wandb.run is not None:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{epoch}.safe"
else:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{epoch}.safe"
save_checkpoint(get_state_dict(model), fn)
if is_successful or diverged:
break
def train_rnnt():
# TODO: RNN-T
pass
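# train_step_bert applies loss scaling plus gradient clipping by global norm: gradients are unscaled,
# the L2 norm over all parameters is computed, and every gradient is divided by max(global_norm, 1.0)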
@TinyJit
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
optimizer.zero_grad()
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
(loss * loss_scaler).backward()
global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
for p in optimizer.params:
p.grad = p.grad / loss_scaler
global_norm += p.grad.float().square().sum()
global_norm = global_norm.sqrt()
for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)
optimizer.step()
scheduler.step()
return loss.realize()
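# eval_step_bert returns masked-LM / next-sentence accuracies and losses as a dict of realized tensors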
@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
return {
"masked_lm_accuracy": masked_lm_accuracy.realize(),
"next_sentence_accuracy": seq_relationship_accuracy.realize(),
"masked_lm_loss": masked_lm_loss.realize(),
"next_sentence_loss": next_sentence_loss.realize()
}
def train_bert():
# NOTE: requires tensorflow and wandb (pip install tensorflow wandb)
from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert, get_fake_data_bert
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
config = {}
BASEDIR = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki")
GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
print(f"training on {GPUS}")
for x in GPUS: Device[x]
seed = config["seed"] = getenv("SEED", 12345)
INITMLPERF = getenv("INITMLPERF")
RUNMLPERF = getenv("RUNMLPERF")
BENCHMARK = getenv("BENCHMARK")
if getenv("LOGMLPERF"):
from mlperf_logging import mllog
import mlperf_logging.mllog.constants as mllog_constants
mllog.config(filename="bert.log")
mllog.config(root_dir=Path(__file__).parents[3].as_posix())
MLLOGGER = mllog.get_mllogger()
MLLOGGER.logger.propagate = False
if INITMLPERF:
assert BENCHMARK, f"BENCHMARK must be set for INITMLPERF"
MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.BERT)
diskcache_clear()
MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
MLLOGGER.start(key=mllog_constants.INIT_START, value=None)
if RUNMLPERF:
MLLOGGER.start(key=mllog_constants.RUN_START, value=None)
else:
MLLOGGER = None
# ** hyperparameters **
BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.0001 * math.sqrt(BS/66))
train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3000000 // BS)
warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000
eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down
save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**13 if dtypes.default_float == dtypes.float16 else 1.0)
decay = config["DECAY"] = getenv("DECAY", 0.01)
epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)
target, achieved = getenv("TARGET", 0.72), False
config["DEFAULT_FLOAT"] = dtypes.default_float.name
config["DISABLE_DROPOUT"] = getenv("DISABLE_DROPOUT", 0)
config["TRAIN_BEAM"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
config["EVAL_BEAM"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
Tensor.manual_seed(seed) # seed for weight initialization
model = get_mlperf_bert_model(init_ckpt if not INITMLPERF else None)
for _, x in get_state_dict(model).items():
x.realize().to_(GPUS)
parameters = get_parameters(model)
assert 10000 <= (EVAL_BS * max_eval_steps), "Evaluation batch size * max_eval_steps must be greater than or equal to 10000 to iterate over the full eval dataset"
# ** Log run config **
for key, value in config.items(): print(f'HParam: "{key}": {value}')
# ** Optimizer **
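# bias and LayerNorm parameters are excluded from weight decay; all other parameters use LAMB with weight decay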
parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
parameters = [x for x in parameters if x not in set(parameters_no_wd)]
optimizer_wd = LAMB(parameters, lr=max_lr, eps=epsilon, weight_decay=decay, adam=False)
optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=epsilon, weight_decay=0.0, adam=False)
optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)
# ** LR scheduler **
scheduler_wd = PolynomialDecayWithWarmup(optimizer_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
scheduler_no_wd = PolynomialDecayWithWarmup(optimizer_no_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)
print(f"training with batch size {BS} for one epoch with {train_steps} steps")
# log mlperf hparams
if MLLOGGER:
if RUNMLPERF:
MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=config["GLOBAL_BATCH_SIZE"])
MLLOGGER.event(key=mllog_constants.MAX_SEQUENCE_LENGTH, value=512)
MLLOGGER.event(key="max_predictions_per_seq", value=76)
MLLOGGER.event(key=mllog_constants.OPT_NAME, value="LAMB")
MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=config["OPT_BASE_LEARNING_RATE"])
MLLOGGER.event(key=mllog_constants.OPT_LAMB_WEIGHT_DECAY, value=config["DECAY"])
MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_1, value=optimizer_wd.b1)
MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_2, value=optimizer_wd.b2)
MLLOGGER.event(key=mllog_constants.OPT_LAMB_LR_DECAY_POLY_POWER, value=config["POLY_POWER"])
MLLOGGER.event(key=mllog_constants.OPT_LAMB_EPSILON, value=config["EPSILON"])
MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
MLLOGGER.event(key=mllog_constants.NUM_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
MLLOGGER.event(key='start_warmup_step', value=0)
MLLOGGER.event(key='opt_learning_rate_training_steps', value=config["TRAIN_STEPS"])
MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=config["EVAL_BS"] * config["MAX_EVAL_STEPS"])
MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=config["GLOBAL_BATCH_SIZE"] * config["TRAIN_STEPS"])
# ** resume from checkpointing **
start_step = 0
previous_step = None
if ckpt:=getenv("RESUME", ""):
load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
start_step = int(scheduler_wd.epoch_counter.numpy().item())
print(f"resuming from {ckpt} at step {start_step}")
# ** init wandb **
WANDB = getenv("WANDB")
if WANDB:
import wandb
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
if not INITMLPERF:
eval_it = iter(batch_load_val_bert(EVAL_BS))
train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
for _ in range(start_step): next(train_it) # Fast forward
step_times = []
# ** train loop **
wc_start = time.perf_counter()
if INITMLPERF:
i, train_data = start_step, get_fake_data_bert(GPUS, BS)
else:
i, train_data = start_step, get_data_bert(GPUS, train_it)
while train_data is not None and i < train_steps and not achieved:
Tensor.training = True
BEAM.value = TRAIN_BEAM
st = time.perf_counter()
GlobalCounters.reset()
loss = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])
pt = time.perf_counter()
try:
if INITMLPERF:
next_data = get_fake_data_bert(GPUS, BS)
else:
next_data = get_data_bert(GPUS, train_it)
except StopIteration:
next_data = None
dt = time.perf_counter()
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
loss = loss.numpy().item()
cl = time.perf_counter()
if BENCHMARK: step_times.append(cl - st)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
if WANDB:
wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
"train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})
train_data, next_data = next_data, None
i += 1
if i == BENCHMARK:
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_minutes = int(median_step_time * train_steps / 60)
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
# ** eval loop **
if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK):
if MLLOGGER and RUNMLPERF:
MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": 1, "epoch_count": 1, "step_num": i})
train_step_bert.reset()
eval_lm_losses = []
eval_clsf_losses = []
eval_lm_accs = []
eval_clsf_accs = []
eval_times = []
Tensor.training = False
BEAM.value = EVAL_BEAM
for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
if INITMLPERF:
eval_data = get_fake_data_bert(GPUS, EVAL_BS)
else:
eval_data = get_data_bert(GPUS, eval_it)
GlobalCounters.reset()
st = time.time()
eval_result: dict[str, Tensor] = eval_step_bert(model,
eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"],
eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])
lm_loss, clsf_loss = eval_result["masked_lm_loss"].item(), eval_result["next_sentence_loss"].item()
lm_acc, clsf_acc = eval_result["masked_lm_accuracy"].item(), eval_result["next_sentence_accuracy"].item()
eval_lm_losses.append(lm_loss)
eval_clsf_losses.append(clsf_loss)
eval_lm_accs.append(lm_acc)
eval_clsf_accs.append(clsf_acc)
et = time.time()
eval_times.append(et - st)
if BENCHMARK and j == BENCHMARK:
# assume INITMLPERF has BENCHMARK set
if MLLOGGER and INITMLPERF:
MLLOGGER.event(key=mllog_constants.INIT_STOP, value=None)
return
eval_step_bert.reset()
avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses)
avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)
avg_lm_acc = sum(eval_lm_accs) / len(eval_lm_accs)
avg_clsf_acc = sum(eval_clsf_accs) / len(eval_clsf_accs)
avg_fw_time = sum(eval_times) / len(eval_times)
results = f"eval lm loss: {avg_lm_loss:.2f}, eval clsf loss: {avg_clsf_loss:.2f}, eval lm accuracy: {avg_lm_acc:.6f}, \
eval clsf accuracy: {avg_clsf_acc:.2f}, avg eval step time: {avg_fw_time:.2f}"
tqdm.write(results)
if WANDB:
wandb.log({"eval/lm_loss": avg_lm_loss, "eval/clsf_loss": avg_clsf_loss, "eval/lm_accuracy": avg_lm_acc, \
"eval/clsf_accuracy": avg_clsf_acc, "eval/forward_time": avg_fw_time})
if MLLOGGER and RUNMLPERF:
MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=i, metadata={"epoch_count": 1, "step_num": i, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"]})
MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=avg_lm_acc, metadata={"epoch_num": 1, "masked_lm_accuracy": avg_lm_acc})
# save model if achieved target
if not achieved and avg_lm_acc >= target:
wc_end = time.perf_counter()
if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
fn = f"{ckpt_dir}/bert-large.safe"
safe_save(get_state_dict(model), fn)
print(f" *** Model saved to {fn} ***")
total_seconds = wc_end - wc_start
hours = int(total_seconds // 3600)
minutes = int((total_seconds % 3600) // 60)
seconds = total_seconds % 60
print(f"Reference Convergence point reached after {i * BS} datasamples and {hours}h{minutes}m{seconds:.2f}s.")
achieved = True
if MLLOGGER and RUNMLPERF:
MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
# stop once the target is hit
break
if getenv("CKPT", 1) and i % save_ckpt_freq == 0:
if MLLOGGER and RUNMLPERF:
if previous_step:
MLLOGGER.end(key=mllog_constants.BLOCK_STOP, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "first_step_num": i, "step_num": i, "step_count": i - previous_step})
MLLOGGER.start(key="checkpoint_start", value=None, metadata={"step_num" : i})
if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
if WANDB and wandb.run is not None:
fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
else:
fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
print(f"saving ckpt to {fn}")
safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
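# checkpoint rotation: keep only the newest keep_ckpt_amount checkpoints, deleting the oldest ones (by mtime) first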
ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
while len(ckpt_files) > keep_ckpt_amount:
last = ckpt_files.pop(0)
print(f"Removing old ckpt {last}")
os.remove(os.path.join(ckpt_dir, last))
if MLLOGGER and RUNMLPERF:
MLLOGGER.end(key="checkpoint_stop", value=None, metadata={"step_num": i})
MLLOGGER.start(key=mllog_constants.BLOCK_START, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "epoch_count": 1, "samples_count": i * BS, "step_num": i, "first_step_num": i+1})
previous_step = i
def train_maskrcnn():
# TODO: Mask RCNN
pass
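# MODEL is a comma-separated list selecting which train_* functions to run; names without a matching function are skipped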
if __name__ == "__main__":
multiprocessing.set_start_method('spawn')
with Tensor.train():
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
nm = f"train_{m}"
if nm in globals():
print(f"training {m}")
globals()[nm]()