import os, time, math, functools
from pathlib import Path
import multiprocessing

from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup

from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state

# TODO: fix benchmark logging and use tinygrad tqdm
from tqdm import tqdm

def train_resnet():
  from extra.models import resnet
  from examples.mlperf.dataloader import batch_load_resnet
  from extra.datasets.imagenet import get_train_files, get_val_files
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
  from examples.mlperf.initializers import Conv2dHeNormal, Linear
  from examples.hlb_cifar10 import UnsyncedBatchNorm

  config = {}
  seed = config["seed"] = getenv("SEED", 42)
  Tensor.manual_seed(seed)  # seed for weight initialization

  INITMLPERF = getenv("INITMLPERF")
  RUNMLPERF = getenv("RUNMLPERF")
  if getenv("LOGMLPERF"):
    from mlperf_logging import mllog
    import mlperf_logging.mllog.constants as mllog_constants
    mllog.config(filename=f"result_resnet_{seed}.txt")
    mllog.config(root_dir=Path(__file__).parents[3].as_posix())  # truncate to log this. "file": "tinygrad/examples/mlperf/model_train.py"
    MLLOGGER = mllog.get_mllogger()
    if INITMLPERF:
      # common.yaml
      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
      # closed_common.yaml
      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.RESNET)
      diskcache_clear()
      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
      MLLOGGER.start(key=mllog_constants.INIT_START)
    if RUNMLPERF:
      MLLOGGER.start(key=mllog_constants.RUN_START)
      MLLOGGER.event(key=mllog_constants.SEED, value=seed)
  else:
    MLLOGGER = None

  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  print(f"training on {GPUS}")
  for x in GPUS: Device[x]

  TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)

  # ** model definition and initializers **
  num_classes = 1000
  resnet.Conv2d = Conv2dHeNormal
  resnet.Linear = Linear
  if not getenv("SYNCBN"): resnet.BatchNorm = functools.partial(UnsyncedBatchNorm, num_devices=len(GPUS))
  model = resnet.ResNet50(num_classes)

  # shard weights and initialize in order
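  # when SYNCBN is off, each device keeps its own BatchNorm running statistics, so those
  # buffers are sharded one-per-device (axis=0); every other tensor is replicated to all
  # GPUS with to_() for data-parallel training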
  for k, x in get_state_dict(model).items():
    if not getenv("SYNCBN") and ("running_mean" in k or "running_var" in k):
      x.realize().shard_(GPUS, axis=0)
    else:
      x.realize().to_(GPUS)
  parameters = get_parameters(model)

  # ** hyperparameters **
  epochs = config["epochs"] = getenv("EPOCHS", 37)
  BS = config["BS"] = getenv("BS", 104 * len(GPUS))  # fp32 GPUS<=6 7900xtx can fit BS=112
  EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS)
  base_lr = config["base_lr"] = getenv("LR", 7.2 * (BS/1536))
  lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 2)
  decay = config["decay"] = getenv("DECAY", 2e-4)

  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 256.0 if dtypes.default_float == dtypes.float16 else 1.0)
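  # loss scaling keeps small fp16 gradients from underflowing: the loss is multiplied by
  # loss_scaler before backward() and every gradient is divided by it again before the
  # optimizer step (see train_step below); in fp32 the scaler defaults to 1.0, a no-op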

  target, achieved = getenv("TARGET", 0.759), False
  eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
  eval_freq = getenv("EVAL_FREQ", 1)

  steps_in_train_epoch = config["steps_in_train_epoch"] = (round_up(len(get_train_files()), BS) // BS)
  steps_in_val_epoch = config["steps_in_val_epoch"] = (round_up(len(get_val_files()), EVAL_BS) // EVAL_BS)

  config["DEFAULT_FLOAT"] = dtypes.default_float.name
  config["BEAM"] = BEAM.value
  config["TRAIN_BEAM"] = TRAIN_BEAM
  config["EVAL_BEAM"] = EVAL_BEAM
  config["WINO"] = WINO.value
  config["SYNCBN"] = getenv("SYNCBN")

  # ** Optimizer **
  skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
  parameters = [x for x in parameters if x not in set(skip_list)]
  optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay)
  optimizer_skip = SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True)
  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)
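  # mirroring the MLPerf reference, BatchNorm parameters, biases, and the downsample BN
  # ("downsample.1") are kept out of LARS and trained with plain SGD without weight decay;
  # LARS's layer-wise trust ratio is only applied to the remaining weights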

  # ** LR scheduler **
  scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4,
                                        train_steps=epochs * steps_in_train_epoch,
                                        warmup=lr_warmup_epochs * steps_in_train_epoch)
  scheduler_skip = PolynomialDecayWithWarmup(optimizer_skip, initial_lr=base_lr, end_lr=1e-4,
                                             train_steps=epochs * steps_in_train_epoch,
                                             warmup=lr_warmup_epochs * steps_in_train_epoch)
  scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
  print(f"training with batch size {BS} for {epochs} epochs")

  # log mlperf hparams
  if MLLOGGER:
    if RUNMLPERF:
      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=BS)
      from extra.datasets.imagenet import get_train_files, get_val_files
      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=len(get_train_files()))
      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=len(get_val_files()))

      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
      MLLOGGER.event(key=mllog_constants.OPT_NAME, value="lars")
      assert scheduler.initial_lr == scheduler_skip.initial_lr
      assert scheduler.end_lr == scheduler_skip.end_lr
      assert scheduler.power == scheduler_skip.power
      MLLOGGER.event(key=mllog_constants.LARS_OPT_BASE_LEARNING_RATE, value=scheduler.initial_lr)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_END_LR, value=scheduler.end_lr)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_POLY_POWER, value=scheduler.power)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_STEPS, value=epochs)
      MLLOGGER.event(key=mllog_constants.LARS_EPSILON, value=0)  # does not support epsilon != 0
      MLLOGGER.event(key=mllog_constants.LARS_OPT_LEARNING_RATE_WARMUP_EPOCHS, value=lr_warmup_epochs)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_MOMENTUM, value=optimizer.momentum)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_WEIGHT_DECAY, value=optimizer.wd)

  # ** resume from checkpointing **
  start_epoch = 0
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
    start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch)
    print(f"resuming from {ckpt} at epoch {start_epoch}")

  # ** init wandb **
  WANDB = getenv("WANDB")
  if WANDB:
    import wandb
    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
    wandb.init(config=config, **wandb_args)

  BENCHMARK = getenv("BENCHMARK")

  # ** jitted steps **
  input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
  # mlperf reference resnet does not divide by input_std for some reason
  # input_std = Tensor([0.229, 0.224, 0.225], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
  def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)
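  # the dataloader yields uint8 NHWC batches; normalize() permutes them to the NCHW layout
  # the model expects, subtracts the per-channel ImageNet mean, and casts to the training dtype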
  @TinyJit
  def train_step(X, Y):
    optimizer_group.zero_grad()
    X = normalize(X)
    out = model.forward(X)
    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
    top_1 = (out.argmax(-1) == Y).sum()
    (loss * loss_scaler).backward()
    for t in optimizer_group.params: t.grad = t.grad.contiguous() / loss_scaler
    optimizer_group.step()
    scheduler_group.step()
    return loss.realize(), top_1.realize()

  @TinyJit
  def eval_step(X, Y):
    X = normalize(X)
    out = model.forward(X)
    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
    top_1 = (out.argmax(-1) == Y).sum()
    return loss.realize(), top_1.realize()

  def fake_data_get(batch_size):
    x = Tensor.zeros(batch_size, 224, 224, 3, dtype=dtypes.uchar).contiguous()
    y = [0] * batch_size
    return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, None
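  # INITMLPERF runs use these all-zero batches so JIT capture and BEAM kernel search can
  # happen without the ImageNet files being present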

  def data_get(it):
    x, y, cookie = next(it)
    return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, cookie
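  # the dataloader hands back a "cookie" that pins the shared-memory buffer backing the batch;
  # the loops below hold each cookie until the next step's GPU work is enqueued (prev_cookies),
  # so a buffer is never recycled while a device might still be reading it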

  # ** epoch loop **
  step_times = []
  for e in range(start_epoch, epochs):
    # ** train loop **
    if MLLOGGER and RUNMLPERF:
      MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e+1, metadata=dict(epoch_num=e+1))
    Tensor.training = True
    BEAM.value = TRAIN_BEAM

    if INITMLPERF:
      i, proc = 0, fake_data_get(BS)
    else:
      batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e, pad_first_batch=True)
      it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
      i, proc = 0, data_get(it)

    prev_cookies = []
    st = time.perf_counter()
    while proc is not None:
      GlobalCounters.reset()
      (loss, top_1), y, proc = train_step(proc[0], proc[1]), proc[2], proc[3]

      pt = time.perf_counter()

      if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
      try:
        if INITMLPERF:
          next_proc = fake_data_get(BS)
        else:
          next_proc = data_get(it)
      except StopIteration:
        next_proc = None

      dt = time.perf_counter()

      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
      loss, top_1 = loss.numpy().item(), top_1.numpy().item()
      top_1_acc = top_1 / sum(yi != -1 for yi in y)

      cl = time.perf_counter()
      if BENCHMARK:
        step_times.append(cl - st)

      tqdm.write(
        f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {top_1_acc:3.2f} acc, {optimizer.lr.numpy()[0]:.6f} LR, "
        f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
      if WANDB:
        wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/top_1_acc": top_1_acc, "train/step_time": cl - st,
                   "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})

      st = cl
      prev_cookies.append(proc)
      proc, next_proc = next_proc, None  # return old cookie
      i += 1

      if i == BENCHMARK:
        assert not math.isnan(loss)
        median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
        print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
        print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
              f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
        # if we are doing beam search, run the first eval too
        if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
        return
    if MLLOGGER and RUNMLPERF:
      MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=e+1, metadata=dict(epoch_num=e+1))

    # ** eval loop **
    # always eval for epoch >= 33 to stop the clock as soon as the eval target is hit; convergence can land anywhere in epochs [33, 37]
    if steps_in_val_epoch > 0 and ((e + 1 - eval_start_epoch) % eval_freq == 0 or e + 1 >= 33):
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.start(key=mllog_constants.EVAL_START, value=e+1, metadata=dict(epoch_num=e+1))
      if getenv("RESET_STEP", 1): train_step.reset()  # free the train step memory :(
      eval_times = []
      eval_loss = 0.0
      eval_top_1 = 0
      eval_num_samples = 0
      Tensor.training = False
      BEAM.value = EVAL_BEAM

      if INITMLPERF:
        i, proc = 0, fake_data_get(EVAL_BS)
      else:
        it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False, pad_first_batch=True), total=steps_in_val_epoch))
        i, proc = 0, data_get(it)

      prev_cookies = []
      while proc is not None:
        GlobalCounters.reset()
        st = time.time()

        (loss, top_1), y, proc = eval_step(proc[0], proc[1]), proc[2], proc[3]  # drop inputs, keep cookie

        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
        try:
          if INITMLPERF:
            next_proc = fake_data_get(EVAL_BS)
          else:
            next_proc = data_get(it)
        except StopIteration:
          next_proc = None

        loss, top_1 = loss.numpy().item(), top_1.numpy().item()
        num_samples = sum(yi != -1 for yi in y)
        eval_loss += loss * num_samples
        eval_top_1 += top_1
        eval_num_samples += num_samples
        prev_cookies.append(proc)
        proc, next_proc = next_proc, None
        i += 1
        if i == BENCHMARK:
          # assume INITMLPERF has BENCHMARK set
          if MLLOGGER and INITMLPERF:
            MLLOGGER.event(key=mllog_constants.INIT_STOP)
          return

        et = time.time()
        eval_times.append(et - st)

      if getenv("RESET_STEP", 1): eval_step.reset()
      if not BENCHMARK:
        assert eval_num_samples == len(get_val_files()), f"eval sample count mismatched. {eval_num_samples=} != {len(get_val_files())}"
        total_loss = eval_loss / eval_num_samples
        total_top_1 = eval_top_1 / eval_num_samples
        total_fw_time = sum(eval_times) / len(eval_times)
        tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}")
        if WANDB:
          wandb.log({"eval/loss": total_loss, "eval/top_1_acc": total_top_1, "eval/forward_time": total_fw_time, "epoch": e + 1})
        if MLLOGGER and RUNMLPERF:
          MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=total_top_1, metadata=dict(epoch_num=e+1))
          MLLOGGER.event(key=mllog_constants.EVAL_STOP, value=e+1, metadata=dict(epoch_num=e+1))

        # save model if achieved target
        if not achieved and total_top_1 >= target:
          # stop once the target is achieved
          if MLLOGGER and RUNMLPERF:
            MLLOGGER.event(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
          if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
          fn = f"./ckpts/resnet50_{seed}.safe"
          safe_save(get_state_dict(model), fn)
          print(f" *** Model saved to {fn} ***")
          achieved = True
          break

      # checkpoint every time we eval
      if getenv("CKPT"):
        if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
        if WANDB and wandb.run is not None:
          fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe"
        else:
          fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe"
        print(f"saving ckpt to {fn}")
        safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)

def train_retinanet():
  # TODO: Retinanet
  pass

def train_unet3d():
  """
  Trains the UNet3D model.

  Instructions:
  1) Run the following script from the root folder of `tinygrad`:
  ```./examples/mlperf/scripts/setup_kits19_dataset.sh```

  Optionally, `BASEDIR` can be set to download and process the dataset at a specific location:
  ```BASEDIR=<folder_path> ./examples/mlperf/scripts/setup_kits19_dataset.sh```

  2) To start training the model, run the following:
  ```time PYTHONPATH=. WANDB=1 TRAIN_BEAM=3 FUSE_CONV_BW=1 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```
  """
  from examples.mlperf.losses import dice_ce_loss
  from examples.mlperf.metrics import dice_score
  from examples.mlperf.dataloader import batch_load_unet3d
  from extra.models.unet3d import UNet3D
  from extra.datasets.kits19 import iterate, get_train_files, get_val_files, sliding_window_inference, preprocess_dataset, TRAIN_PREPROCESSED_DIR, VAL_PREPROCESSED_DIR
  from tinygrad import Context
  from tinygrad.nn.optim import SGD
  from math import ceil

  GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  for x in GPUS: Device[x]

  TARGET_METRIC = 0.908
  NUM_EPOCHS = getenv("NUM_EPOCHS", 4000)
  BS = getenv("BS", 1 * len(GPUS))
  LR = getenv("LR", 2.0 * (BS / 28))
  LR_WARMUP_EPOCHS = getenv("LR_WARMUP_EPOCHS", 1000)
  LR_WARMUP_INIT_LR = getenv("LR_WARMUP_INIT_LR", 0.0001)
  WANDB = getenv("WANDB")
  PROJ_NAME = getenv("PROJ_NAME", "tinygrad_unet3d_mlperf")
  SEED = getenv("SEED", -1) if getenv("SEED", -1) >= 0 else None
  TRAIN_DATASET_SIZE, VAL_DATASET_SIZE = len(get_train_files()), len(get_val_files())
  SAMPLES_PER_EPOCH = TRAIN_DATASET_SIZE // BS
  START_EVAL_AT = getenv("START_EVAL_AT", ceil(1000 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
  EVALUATE_EVERY = getenv("EVALUATE_EVERY", ceil(20 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
  TRAIN_BEAM, EVAL_BEAM = getenv("TRAIN_BEAM", BEAM.value), getenv("EVAL_BEAM", BEAM.value)
  BENCHMARK = getenv("BENCHMARK")
  CKPT = getenv("CKPT")

  config = {
    "num_epochs": NUM_EPOCHS,
    "batch_size": BS,
    "learning_rate": LR,
    "learning_rate_warmup_epochs": LR_WARMUP_EPOCHS,
    "learning_rate_warmup_init": LR_WARMUP_INIT_LR,
    "start_eval_at": START_EVAL_AT,
    "evaluate_every": EVALUATE_EVERY,
    "train_beam": TRAIN_BEAM,
    "eval_beam": EVAL_BEAM,
    "wino": WINO.value,
    "fuse_conv_bw": FUSE_CONV_BW.value,
    "gpus": GPUS,
    "default_float": dtypes.default_float.name
  }

  if WANDB:
    try:
      import wandb
    except ImportError:
      raise ImportError("Need to install wandb to use it")

  if SEED is not None:
    config["seed"] = SEED
    Tensor.manual_seed(SEED)

  model = UNet3D()
  params = get_parameters(model)

  for p in params: p.realize().to_(GPUS)

  optim = SGD(params, lr=LR, momentum=0.9, nesterov=True)

  def lr_warm_up(optim, init_lr, lr, current_epoch, warmup_epochs):
    scale = current_epoch / warmup_epochs
    optim.lr.assign(Tensor([init_lr + (lr - init_lr) * scale], device=GPUS)).realize()
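  # linear warmup: lr(e) = init_lr + (lr - init_lr) * e / warmup_epochs for e <= warmup_epochs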

  def save_checkpoint(state_dict, fn):
    if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
    print(f"saving checkpoint to {fn}")
    safe_save(state_dict, fn)

  def data_get(it):
    x, y, cookie = next(it)
    return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie

  @TinyJit
  @Tensor.train()
  def train_step(model, x, y):
    optim.zero_grad()

    y_hat = model(x)
    loss = dice_ce_loss(y_hat, y)

    loss.backward()
    optim.step()
    return loss.realize()

  @Tensor.train(mode=False)
  @Tensor.test()
  def eval_step(model, x, y):
    y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
    y_hat, y = Tensor(y_hat), Tensor(y, requires_grad=False)
    loss = dice_ce_loss(y_hat, y)
    score = dice_score(y_hat, y)
    return loss.realize(), score.realize()
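  # full KiTS19 volumes are too large for one forward pass, so eval runs the model
  # patch-by-patch via sliding_window_inference and stitches the patch predictions back
  # into a whole-volume prediction before computing the dice score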

  if WANDB: wandb.init(config=config, project=PROJ_NAME)

  step_times, start_epoch = [], 1
  is_successful, diverged = False, False
  start_eval_at, evaluate_every = 1 if BENCHMARK else START_EVAL_AT, 1 if BENCHMARK else EVALUATE_EVERY
  next_eval_at = start_eval_at

  print(f"Training on {GPUS}")

  if BENCHMARK: print("Benchmarking UNet3D")
  else: print(f"Start evaluation at epoch {start_eval_at} and every {evaluate_every} epoch(s) afterwards")

  if not TRAIN_PREPROCESSED_DIR.exists(): preprocess_dataset(get_train_files(), TRAIN_PREPROCESSED_DIR, False)
  if not VAL_PREPROCESSED_DIR.exists(): preprocess_dataset(get_val_files(), VAL_PREPROCESSED_DIR, True)

  for epoch in range(1, NUM_EPOCHS + 1):
    with Context(BEAM=TRAIN_BEAM):
      if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0:
        lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS)

      train_dataloader = batch_load_unet3d(TRAIN_PREPROCESSED_DIR, batch_size=BS, val=False, shuffle=True, seed=SEED)
      it = iter(tqdm(train_dataloader, total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK))
      i, proc = 0, data_get(it)

      prev_cookies = []
      st = time.perf_counter()

      while proc is not None:
        GlobalCounters.reset()

        loss, proc = train_step(model, proc[0], proc[1]), proc[2]

        pt = time.perf_counter()

        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
        try:
          next_proc = data_get(it)
        except StopIteration:
          next_proc = None

        dt = time.perf_counter()

        device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
        loss = loss.numpy().item()

        cl = time.perf_counter()

        if BENCHMARK: step_times.append(cl - st)

        tqdm.write(
          f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
          f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optim.lr.numpy()[0]:.6f} LR, "
          f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
        )

        if WANDB:
          wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt,
                     "train/cl_time": cl - dt, "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})

        st = cl
        prev_cookies.append(proc)
        proc, next_proc = next_proc, None  # return old cookie
        i += 1

        if i == BENCHMARK:
          median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
          estimated_total_minutes = int(median_step_time * SAMPLES_PER_EPOCH * NUM_EPOCHS / 60)
          print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
          if (TRAIN_BEAM or EVAL_BEAM) and epoch == start_epoch: break
          return

    with Context(BEAM=EVAL_BEAM):
      if epoch == next_eval_at:
        next_eval_at += evaluate_every
        eval_loss = []
        scores = []

        for x, y in tqdm(iterate(get_val_files(), preprocessed_dir=VAL_PREPROCESSED_DIR), total=VAL_DATASET_SIZE):
          eval_loss_value, score = eval_step(model, x, y)
          eval_loss.append(eval_loss_value)
          scores.append(score)

        scores = Tensor.mean(Tensor.stack(*scores, dim=0), axis=0).numpy()
        eval_loss = Tensor.mean(Tensor.stack(*eval_loss, dim=0), axis=0).numpy()

        l1_dice, l2_dice = scores[0][-2], scores[0][-1]
        mean_dice = (l2_dice + l1_dice) / 2

        tqdm.write(f"{l1_dice} L1 dice, {l2_dice} L2 dice, {mean_dice:.3f} mean_dice, {eval_loss:5.2f} eval_loss")

        if WANDB:
          wandb.log({"eval/loss": eval_loss, "eval/mean_dice": mean_dice, "epoch": epoch})

        if mean_dice >= TARGET_METRIC:
          is_successful = True
          save_checkpoint(get_state_dict(model), "./ckpts/unet3d.safe")
        elif mean_dice < 1e-6:
          print("Model diverging. Aborting.")
          diverged = True

        if not is_successful and CKPT:
          if WANDB and wandb.run is not None:
            fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{epoch}.safe"
          else:
            fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{epoch}.safe"

          save_checkpoint(get_state_dict(model), fn)

    if is_successful or diverged:
      break

def train_rnnt():
  # TODO: RNN-T
  pass

@TinyJit
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
                    masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
    t.shard_(GPUS, axis=0)
  optimizer.zero_grad()

  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
  loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
  (loss * loss_scaler).backward()
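
  # unscale the gradients, then clip by global norm: compute the L2 norm over all gradients
  # and divide each one by max(global_norm, 1.0), shrinking the update only when the combined
  # norm exceeds 1.0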
  global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
  for p in optimizer.params:
    p.grad = p.grad / loss_scaler
    global_norm += p.grad.float().square().sum()
  global_norm = global_norm.sqrt()
  for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)

  optimizer.step()
  scheduler.step()
  return loss.realize(), global_norm.realize()

@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
                   masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
    t.shard_(GPUS, axis=0)
  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
  masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
    model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
  return masked_lm_accuracy.realize(), seq_relationship_accuracy.realize(), masked_lm_loss.realize(), next_sentence_loss.realize()

def train_bert():
  # NOTE: pip install tensorflow, wandb required
  from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
  from examples.mlperf.helpers import get_mlperf_bert_model, get_fake_data_bert
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup

  config = {}
  BASEDIR = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki")

  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  print(f"training on {GPUS}")
  for x in GPUS: Device[x]
  seed = config["seed"] = getenv("SEED", 12345)

  INITMLPERF = getenv("INITMLPERF")
  RUNMLPERF = getenv("RUNMLPERF")
  BENCHMARK = getenv("BENCHMARK")
  if getenv("LOGMLPERF"):
    from mlperf_logging import mllog
    import mlperf_logging.mllog.constants as mllog_constants

    mllog.config(filename=f"result_bert_{seed}.log")
    mllog.config(root_dir=Path(__file__).parents[3].as_posix())
    MLLOGGER = mllog.get_mllogger()
    MLLOGGER.logger.propagate = False

    if INITMLPERF:
      assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)

      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.BERT)

      diskcache_clear()
      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
      MLLOGGER.start(key=mllog_constants.INIT_START, value=None)

    if RUNMLPERF:
      MLLOGGER.start(key=mllog_constants.RUN_START, value=None)
      MLLOGGER.event(key=mllog_constants.SEED, value=seed)
  else:
    MLLOGGER = None

  # ** hyperparameters **
  BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
  EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.00011 * math.sqrt(BS/66))

  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3630000 // BS)
  warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
  max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS)  # EVAL_BS * MAX_EVAL_STEPS >= 10000
  eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS))  # Round down
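  # default eval cadence follows the MLPerf BERT rule of evaluating every
  # floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000 training samples, converted to steps by dividing by BS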
  save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
  keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
  save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
  init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)

  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**10 if dtypes.default_float == dtypes.float16 else 1.0)
  decay = config["DECAY"] = getenv("DECAY", 0.01)
  epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
  poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)

  target, achieved = getenv("TARGET", 0.72), False

  config["DEFAULT_FLOAT"] = dtypes.default_float.name
  config["DISABLE_DROPOUT"] = getenv("DISABLE_DROPOUT", 0)
  config["TRAIN_BEAM"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  config["EVAL_BEAM"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)

  Tensor.manual_seed(seed)  # seed for weight initialization

  assert 10000 <= (EVAL_BS * max_eval_steps), "EVAL_BS * max_eval_steps must be greater than or equal to 10000 to iterate over the full eval dataset"

  # ** init wandb **
  WANDB = getenv("WANDB")
  if WANDB:
    import wandb
    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
    wandb.init(config=config, **wandb_args, project="MLPerf-BERT")

  # ** init model **

  model = get_mlperf_bert_model()
  if RUNMLPERF:
    model.load_from_pretrained(init_ckpt)
  else:
    # for init, zero out all weights
    for p in get_parameters(model):
      p = p.assign(Tensor.zeros_like(p).contiguous()).realize()

  parameters = get_parameters(model)
  for p in parameters:
    p.to_(GPUS)
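  # data-parallel setup: every parameter is replicated across GPUS with to_(); the per-step
  # batch is what gets sharded (axis=0) inside train_step_bert/eval_step_bert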

  # ** Log run config **
  for key, value in config.items(): print(f'HParam: "{key}": {value}')

  # ** Optimizer **
  parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
  parameters = [x for x in parameters if x not in set(parameters_no_wd)]
  optimizer_wd = LAMB(parameters, lr=max_lr, eps=epsilon, weight_decay=decay, adam=False)
  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=epsilon, weight_decay=0.0, adam=False)
  optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)

  # ** LR scheduler **
  scheduler_wd = PolynomialDecayWithWarmup(optimizer_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
  scheduler_no_wd = PolynomialDecayWithWarmup(optimizer_no_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
  scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)
  print(f"training with batch size {BS} for one epoch with {train_steps} steps")

  # log mlperf hparams
  if MLLOGGER:
    if RUNMLPERF:
      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=config["GLOBAL_BATCH_SIZE"])
      MLLOGGER.event(key=mllog_constants.MAX_SEQUENCE_LENGTH, value=512)
      MLLOGGER.event(key="max_predictions_per_seq", value=76)

      MLLOGGER.event(key=mllog_constants.OPT_NAME, value="LAMB")
      MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=config["OPT_BASE_LEARNING_RATE"])
      MLLOGGER.event(key=mllog_constants.OPT_LAMB_WEIGHT_DECAY, value=config["DECAY"])
      MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_1, value=optimizer_wd.b1)
      MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_2, value=optimizer_wd.b2)
      MLLOGGER.event(key=mllog_constants.OPT_LAMB_LR_DECAY_POLY_POWER, value=config["POLY_POWER"])
      MLLOGGER.event(key=mllog_constants.OPT_LAMB_EPSILON, value=config["EPSILON"])

      MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
      MLLOGGER.event(key=mllog_constants.NUM_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
      MLLOGGER.event(key='start_warmup_step', value=0)
      MLLOGGER.event(key='opt_learning_rate_training_steps', value=config["TRAIN_STEPS"])
      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=config["EVAL_BS"] * config["MAX_EVAL_STEPS"])
      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=config["GLOBAL_BATCH_SIZE"] * config["TRAIN_STEPS"])

  # ** resume from checkpointing **
  start_step = 0
  previous_step = None
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
    start_step = int(scheduler_wd.epoch_counter.item())
    print(f"resuming from {ckpt} at step {start_step}")

  if RUNMLPERF:
    # only load real data with RUNMLPERF
    eval_it = iter(batch_load_val_bert(EVAL_BS))
    train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
    for _ in range(start_step): next(train_it)  # Fast forward
  else:
    # repeat fake data
    def repeat_fake(bs):
      while True: yield get_fake_data_bert(bs)
    eval_it = iter(repeat_fake(EVAL_BS))
    train_it = iter(repeat_fake(BS))

  step_times = []
  # ** train loop **
  wc_start = time.perf_counter()

  i, train_data = start_step, next(train_it)

  if RUNMLPERF:
    if MLLOGGER:
      MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*BS, metadata={"epoch_num": i*BS})

  while train_data is not None and i < train_steps and not achieved:
    Tensor.training = True
    BEAM.value = TRAIN_BEAM
    st = time.perf_counter()
    GlobalCounters.reset()
    loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
      train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"],
      train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)

    pt = time.perf_counter()

    try:
      next_data = next(train_it)
    except StopIteration:
      next_data = None

    dt = time.perf_counter()

    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
    loss = loss.item()

    cl = time.perf_counter()
    if BENCHMARK: step_times.append(cl - st)

    tqdm.write(
      f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
      f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
    if WANDB:
      wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
                 "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                 "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})

    train_data, next_data = next_data, None
    i += 1

    if i == BENCHMARK:
      median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
      estimated_total_minutes = int(median_step_time * train_steps / 60)
      print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
      print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
            f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")

    # ** eval loop **
    if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i})
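      # free memory held by the jitted train step before eval: RESET_STEP (or INITMLPERF)
      # throws away the whole capture, otherwise only intermediate buffers are freed so the
      # step does not need to be re-jitted afterwards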
      if getenv("RESET_STEP", 0) or INITMLPERF: train_step_bert.reset()
      else: train_step_bert.captured.free_intermediates()
      eval_lm_losses = []
      eval_clsf_losses = []
      eval_lm_accs = []
      eval_clsf_accs = []
      eval_times = []
      Tensor.training = False
      BEAM.value = EVAL_BEAM

      for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
        eval_data = next(eval_it)
        GlobalCounters.reset()
        st = time.time()

        lm_acc, clsf_acc, lm_loss, clsf_loss = eval_step_bert(model,
          eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"],
          eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"], GPUS)
        lm_acc, clsf_acc, lm_loss, clsf_loss = lm_acc.item(), clsf_acc.item(), lm_loss.item(), clsf_loss.item()

        eval_lm_losses.append(lm_loss)
        eval_clsf_losses.append(clsf_loss)
        eval_lm_accs.append(lm_acc)
        eval_clsf_accs.append(clsf_acc)

        et = time.time()
        eval_times.append(et - st)

        if BENCHMARK and j == BENCHMARK:
          # assume INITMLPERF has BENCHMARK set
          if MLLOGGER and INITMLPERF:
            MLLOGGER.event(key=mllog_constants.INIT_STOP, value=None)
          return

      if getenv("RESET_STEP", 0): eval_step_bert.reset()
      else: eval_step_bert.captured.free_intermediates()
      del eval_data
      avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses)
      avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)
      avg_lm_acc = sum(eval_lm_accs) / len(eval_lm_accs)
      avg_clsf_acc = sum(eval_clsf_accs) / len(eval_clsf_accs)
      avg_fw_time = sum(eval_times) / len(eval_times)
      results = f"eval lm loss: {avg_lm_loss:.2f}, eval clsf loss: {avg_clsf_loss:.2f}, eval lm accuracy: {avg_lm_acc:.6f}, " \
                f"eval clsf accuracy: {avg_clsf_acc:.2f}, avg eval step time: {avg_fw_time:.2f}"
      tqdm.write(results)

      if WANDB:
        wandb.log({"eval/lm_loss": avg_lm_loss, "eval/clsf_loss": avg_clsf_loss, "eval/lm_accuracy": avg_lm_acc,
                   "eval/clsf_accuracy": avg_clsf_acc, "eval/forward_time": avg_fw_time})

      if MLLOGGER and RUNMLPERF:
        MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=i*BS, metadata={"epoch_count": i*BS, "step_num": i, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"]})
        MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=avg_lm_acc, metadata={"epoch_num": i*BS, "masked_lm_accuracy": avg_lm_acc})

      # save model if achieved target
      if not achieved and avg_lm_acc >= target:
        wc_end = time.perf_counter()
        if getenv("CKPT"):
          if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
          fn = f"{ckpt_dir}/bert-large.safe"
          safe_save(get_state_dict(model), fn)
          print(f" *** Model saved to {fn} ***")

        total_seconds = wc_end - wc_start
        hours = int(total_seconds // 3600)
        minutes = int((total_seconds % 3600) // 60)
        seconds = total_seconds % 60
        print(f"Reference convergence point reached after {i * BS} data samples and {hours}h{minutes}m{seconds:.2f}s.")
        achieved = True
        if MLLOGGER and RUNMLPERF:
          MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=i*BS, metadata={"epoch_num": i*BS})
          MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
        # stop once the target is hit
        break

    if getenv("CKPT") and i % save_ckpt_freq == 0:
      if MLLOGGER and RUNMLPERF:
        if previous_step:
          MLLOGGER.end(key=mllog_constants.BLOCK_STOP, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "first_step_num": i, "step_num": i, "step_count": i - previous_step})
        MLLOGGER.start(key="checkpoint_start", value=None, metadata={"step_num": i})
      if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
      if WANDB and wandb.run is not None:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
      else:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
      print(f"saving ckpt to {fn}")
      safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
      ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
      ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
      while len(ckpt_files) > keep_ckpt_amount:
        last = ckpt_files.pop(0)
        print(f"Removing old ckpt {last}")
        os.remove(os.path.join(ckpt_dir, last))
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.end(key="checkpoint_stop", value=None, metadata={"step_num": i})
        MLLOGGER.start(key=mllog_constants.BLOCK_START, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "epoch_count": 1, "samples_count": i * BS, "step_num": i, "first_step_num": i+1})
      previous_step = i

def train_maskrcnn():
  # TODO: Mask RCNN
  pass

if __name__ == "__main__":
  multiprocessing.set_start_method('spawn')
  with Tensor.train():
    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
      nm = f"train_{m}"
      if nm in globals():
        print(f"training {m}")
        globals()[nm]()