import os, time, math, functools
from pathlib import Path
from tqdm import tqdm
import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state
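# NOTE: each train_* function below is configured entirely through environment variables read
# with getenv (e.g. GPUS, BS, BEAM, BENCHMARK, WANDB, RESUME); the MODEL variable in the
# __main__ block at the bottom of this file selects which benchmarks actually run.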
def train_resnet():
  from extra.models import resnet
  from examples.mlperf.dataloader import batch_load_resnet
  from extra.datasets.imagenet import get_train_files, get_val_files
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
  from examples.mlperf.initializers import Conv2dHeNormal, Linear
  from examples.hlb_cifar10 import UnsyncedBatchNorm

  config = {}
  seed = config["seed"] = getenv("SEED", 42)
  Tensor.manual_seed(seed) # seed for weight initialization

  INITMLPERF = getenv("INITMLPERF")
  RUNMLPERF = getenv("RUNMLPERF")
  if getenv("LOGMLPERF"):
    from mlperf_logging import mllog
    import mlperf_logging.mllog.constants as mllog_constants
    mllog.config(filename=f"result_{seed}.txt")
    mllog.config(root_dir=Path(__file__).parents[3].as_posix()) # truncate to log this. "file": "tinygrad/examples/mlperf/model_train.py"
    MLLOGGER = mllog.get_mllogger()
    if INITMLPERF:
      # common.yaml
      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
      # closed_common.yaml
      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.RESNET)
      diskcache_clear()
      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
      MLLOGGER.start(key=mllog_constants.INIT_START)
    if RUNMLPERF:
      MLLOGGER.start(key=mllog_constants.RUN_START)
      MLLOGGER.event(key=mllog_constants.SEED, value=seed)
  else:
    MLLOGGER = None

  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  print(f"training on {GPUS}")
  for x in GPUS: Device[x]

  TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
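  # TRAIN_BEAM/EVAL_BEAM let the BEAM kernel-search setting differ between the train and eval
  # jitted steps; BEAM.value is switched to the matching value before each phase in the epoch loop.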
  # ** model definition and initializers **
  num_classes = 1000
  resnet.Conv2d = Conv2dHeNormal
  resnet.Linear = Linear
  if not getenv("SYNCBN"): resnet.BatchNorm = functools.partial(UnsyncedBatchNorm, num_devices=len(GPUS))
  model = resnet.ResNet50(num_classes)

  # shard weights and initialize in order
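  # with UnsyncedBatchNorm each device keeps its own running_mean/running_var, so those buffers are
  # sharded across GPUS (axis=0); all other parameters are replicated to every device with to_().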
  for k, x in get_state_dict(model).items():
    if not getenv("SYNCBN") and ("running_mean" in k or "running_var" in k):
      x.realize().shard_(GPUS, axis=0)
    else:
      x.realize().to_(GPUS)
  parameters = get_parameters(model)

  # ** hyperparameters **
  epochs = config["epochs"] = getenv("EPOCHS", 37)
  BS = config["BS"] = getenv("BS", 104 * len(GPUS)) # fp32 GPUS<=6 7900xtx can fit BS=112
  EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS)
  base_lr = config["base_lr"] = getenv("LR", 7.2 * (BS/1536))
  lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 2)
  decay = config["decay"] = getenv("DECAY", 2e-4)
  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 128.0 if dtypes.default_float == dtypes.float16 else 1.0)
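  # LOSS_SCALER is for fp16 training: the loss is multiplied by it before backward() and the gradients
  # are divided back down inside train_step, so small gradients don't underflow in half precision.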
  target, achieved = getenv("TARGET", 0.759), False
  eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
  eval_freq = getenv("EVAL_FREQ", 1)
  steps_in_train_epoch = config["steps_in_train_epoch"] = (round_up(len(get_train_files()), BS) // BS)
  steps_in_val_epoch = config["steps_in_val_epoch"] = (round_up(len(get_val_files()), EVAL_BS) // EVAL_BS)
  config["DEFAULT_FLOAT"] = dtypes.default_float.name
  config["BEAM"] = BEAM.value
  config["TRAIN_BEAM"] = TRAIN_BEAM
  config["EVAL_BEAM"] = EVAL_BEAM
  config["WINO"] = WINO.value
  config["SYNCBN"] = getenv("SYNCBN")

  # ** Optimizer **
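  # LARS with weight decay is applied only to conv/linear weights; batchnorm parameters, biases and
  # the downsample batchnorm ("downsample.1") go on a skip list and are trained with plain momentum
  # SGD and no weight decay, matching the usual LARS skip-list treatment in the MLPerf ResNet recipe.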
  skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
  parameters = [x for x in parameters if x not in set(skip_list)]
  optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay)
  optimizer_skip = SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True)
  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)

  # ** LR scheduler **
  scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4,
                                        train_steps=epochs * steps_in_train_epoch,
                                        warmup=lr_warmup_epochs * steps_in_train_epoch)
  scheduler_skip = PolynomialDecayWithWarmup(optimizer_skip, initial_lr=base_lr, end_lr=1e-4,
                                             train_steps=epochs * steps_in_train_epoch,
                                             warmup=lr_warmup_epochs * steps_in_train_epoch)
  scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
  print(f"training with batch size {BS} for {epochs} epochs")

  # log mlperf hparams
  if MLLOGGER:
    if RUNMLPERF:
      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=BS)
      from extra.datasets.imagenet import get_train_files, get_val_files
      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=len(get_train_files()))
      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=len(get_val_files()))
      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
      MLLOGGER.event(key=mllog_constants.OPT_NAME, value="lars")
      assert scheduler.initial_lr == scheduler_skip.initial_lr
      assert scheduler.end_lr == scheduler_skip.end_lr
      assert scheduler.power == scheduler_skip.power
      MLLOGGER.event(key=mllog_constants.LARS_OPT_BASE_LEARNING_RATE, value=scheduler.initial_lr)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_END_LR, value=scheduler.end_lr)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_POLY_POWER, value=scheduler.power)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_STEPS, value=epochs)
      MLLOGGER.event(key=mllog_constants.LARS_EPSILON, value=0) # does not support epsilon != 0
      MLLOGGER.event(key=mllog_constants.LARS_OPT_LEARNING_RATE_WARMUP_EPOCHS, value=lr_warmup_epochs)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_MOMENTUM, value=optimizer.momentum)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_WEIGHT_DECAY, value=optimizer.wd)

  # ** resume from checkpointing **
  start_epoch = 0
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
    start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch)
    print(f"resuming from {ckpt} at epoch {start_epoch}")

  # ** init wandb **
  WANDB = getenv("WANDB")
  if WANDB:
    import wandb
    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
    wandb.init(config=config, **wandb_args)

  BENCHMARK = getenv("BENCHMARK")

  # ** jitted steps **
  input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
  # mlperf reference resnet does not divide by input_std for some reason
  # input_std = Tensor([0.229, 0.224, 0.225], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
  def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)
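  # train_step/eval_step are TinyJit-compiled: the first calls capture the kernel schedule and later
  # calls replay it directly, so their tensor inputs must keep the same shapes and devices.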
  @TinyJit
  def train_step(X, Y):
    optimizer_group.zero_grad()
    X = normalize(X)
    out = model.forward(X)
    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
    top_1 = (out.argmax(-1) == Y).sum()
    (loss * loss_scaler).backward()
    for t in optimizer_group.params: t.grad = t.grad.contiguous() / loss_scaler
    optimizer_group.step()
    scheduler_group.step()
    return loss.realize(), top_1.realize()

  @TinyJit
  def eval_step(X, Y):
    X = normalize(X)
    out = model.forward(X)
    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
    top_1 = (out.argmax(-1) == Y).sum()
    return loss.realize(), top_1.realize()
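  # data_get pulls (images, labels, cookie) from the dataloader; the cookie is held in prev_cookies
  # until the GPU work for the batch has been enqueued, keeping the loader's shared-memory slot alive.
  # fake_data_get builds an all-zero batch of the same shape, used for INITMLPERF/BEAM warmup runs.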
  def fake_data_get(batch_size):
    x = Tensor.zeros(batch_size, 224, 224, 3, dtype=dtypes.uchar).contiguous()
    y = [0] * batch_size
    return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, None

  def data_get(it):
    x, y, cookie = next(it)
    return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, cookie

  # ** epoch loop **
  step_times = []
  for e in range(start_epoch, epochs):
    # ** train loop **
    if MLLOGGER and RUNMLPERF:
      MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e+1, metadata=dict(epoch_num=e+1))
    Tensor.training = True
    BEAM.value = TRAIN_BEAM
    if INITMLPERF:
      i, proc = 0, fake_data_get(BS)
    else:
      batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e, pad_first_batch=True)
      it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
      i, proc = 0, data_get(it)

    prev_cookies = []
    st = time.perf_counter()
    while proc is not None:
      GlobalCounters.reset()
      (loss, top_1), y, proc = train_step(proc[0], proc[1]), proc[2], proc[3]
      pt = time.perf_counter()

      if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
      try:
        if INITMLPERF:
          next_proc = fake_data_get(BS)
        else:
          next_proc = data_get(it)
      except StopIteration:
        next_proc = None

      dt = time.perf_counter()

      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
      loss, top_1 = loss.numpy().item(), top_1.numpy().item()
      top_1_acc = top_1 / sum(yi != -1 for yi in y)

      cl = time.perf_counter()
      if BENCHMARK:
        step_times.append(cl - st)

      tqdm.write(
        f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {top_1_acc:3.2f} acc, {optimizer.lr.numpy()[0]:.6f} LR, "
        f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
      if WANDB:
        wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/top_1_acc": top_1_acc, "train/step_time": cl - st,
                   "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})

      st = cl
      prev_cookies.append(proc)
      proc, next_proc = next_proc, None # return old cookie
      i += 1

      if i == BENCHMARK:
        assert not math.isnan(loss)
        median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
        print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
        print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
              f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
        # if we are doing beam search, run the first eval too
        if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
        return

    if MLLOGGER and RUNMLPERF:
      MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=e+1, metadata=dict(epoch_num=e+1))

    # ** eval loop **
    # always eval for epoch >= 33 to stop the clock as soon as the eval target is hit; it can converge in any epoch in [33, 37]
    if steps_in_val_epoch > 0 and ((e + 1 - eval_start_epoch) % eval_freq == 0 or e + 1 >= 33):
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.start(key=mllog_constants.EVAL_START, value=e+1, metadata=dict(epoch_num=e+1))
      if getenv("RESET_STEP", 1): train_step.reset() # free the train step memory :(
      eval_times = []
      eval_loss = 0.0
      eval_top_1 = 0
      eval_num_samples = 0
      Tensor.training = False
      BEAM.value = EVAL_BEAM

      if INITMLPERF:
        i, proc = 0, fake_data_get(EVAL_BS)
      else:
        it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False, pad_first_batch=True), total=steps_in_val_epoch))
        i, proc = 0, data_get(it)

      prev_cookies = []
      while proc is not None:
        GlobalCounters.reset()
        st = time.time()

        (loss, top_1), y, proc = eval_step(proc[0], proc[1]), proc[2], proc[3] # drop inputs, keep cookie

        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
        try:
          if INITMLPERF:
            next_proc = fake_data_get(EVAL_BS)
          else:
            next_proc = data_get(it)
        except StopIteration:
          next_proc = None

        loss, top_1 = loss.numpy().item(), top_1.numpy().item()
        num_samples = sum(yi != -1 for yi in y)
        eval_loss += loss * num_samples
        eval_top_1 += top_1
        eval_num_samples += num_samples
        prev_cookies.append(proc)
        proc, next_proc = next_proc, None
        i += 1

        if i == BENCHMARK:
          # assume INITMLPERF has BENCHMARK set
          if MLLOGGER and INITMLPERF:
            MLLOGGER.event(key=mllog_constants.INIT_STOP)
          return

        et = time.time()
        eval_times.append(et - st)

      if getenv("RESET_STEP", 1): eval_step.reset()
      if not BENCHMARK:
        assert eval_num_samples == len(get_val_files()), f"eval sample count mismatched. {eval_num_samples=} != {len(get_val_files())}"
      total_loss = eval_loss / eval_num_samples
      total_top_1 = eval_top_1 / eval_num_samples
      total_fw_time = sum(eval_times) / len(eval_times)
      tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}")
      if WANDB:
        wandb.log({"eval/loss": total_loss, "eval/top_1_acc": total_top_1, "eval/forward_time": total_fw_time, "epoch": e + 1})
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=total_top_1, metadata=dict(epoch_num=e+1))
        MLLOGGER.event(key=mllog_constants.EVAL_STOP, value=e+1, metadata=dict(epoch_num=e+1))

      # save model if achieved target
      if not achieved and total_top_1 >= target:
        # stop once we achieve the target
        if MLLOGGER and RUNMLPERF:
          MLLOGGER.event(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
        if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
        fn = f"./ckpts/resnet50_{seed}.safe"
        safe_save(get_state_dict(model), fn)
        print(f" *** Model saved to {fn} ***")
        achieved = True
        break

      # checkpoint every time we eval
      if getenv("CKPT"):
        if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
        if WANDB and wandb.run is not None:
          fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe"
        else:
          fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe"
        print(f"saving ckpt to {fn}")
        safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
def train_retinanet():
  # TODO: Retinanet
  pass

def train_unet3d():
  # TODO: Unet3d
  pass

def train_rnnt():
  # TODO: RNN-T
  pass
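# jitted BERT steps: the training loss is the masked-LM cross-entropy plus the next-sentence
# classification loss; setting DISABLE_BACKWARD=1 skips the optimizer and runs the forward pass only.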
@TinyJit
def train_step_bert(model, optimizer, scheduler, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
  lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)
  lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
  clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
  loss = lm_loss + clsf_loss

  if not getenv('DISABLE_BACKWARD', 0):
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
  return loss.realize()

@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
  lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)

  clsf_predictions = clsf_logits.log_softmax().argmax(-1)
  clsf_accuracy = (clsf_predictions == next_sentence_labels).float().mean()

  mlm_predictions = lm_logits.log_softmax().argmax(-1)
  mask = (masked_lm_weights == 1.0)
  mlm_accuracy = (mlm_predictions == masked_lm_ids).where(mask, 0).sum() / mask.float().sum()

  lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
  clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
  return {
    "masked_lm_accuracy": mlm_accuracy.realize(),
    "masked_lm_loss": lm_loss.realize(),
    "next_sentence_accuracy": clsf_accuracy.realize(),
    "next_sentence_loss": clsf_loss.realize()
  }
def train_bert():
  # NOTE: pip install tensorflow, wandb required
  from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
  from examples.mlperf.helpers import get_mlperf_bert_model, init_bert_from_checkpoint, get_data_bert
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup

  config = {}
  BASEDIR = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki")

  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  print(f"training on {GPUS}")
  for x in GPUS: Device[x]
  seed = config["seed"] = getenv("SEED", 12345)

  # ** hyperparameters **
  BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 4 * len(GPUS)) # FP32 4090: 6 GPUS -> BS24
  EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 4 * len(GPUS))
  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000004166 * BS)

  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 4800000 // BS)
  warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", train_steps // 10)
  max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000
  eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down
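  # the EVAL_STEP_FREQ default converts an eval cadence defined in samples, floor(0.05 * (230.23 * BS
  # + 3000000) / 25000) * 25000 (presumably the MLPerf BERT rule), into a step count by dividing by BS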
  save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
  keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
  init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)

  decay = config["decay"] = getenv("DECAY", 0.01)
  poly_power = config["poly_power"] = getenv("POLY_POWER", 1.0)

  target, achieved = getenv("TARGET", 0.72), False

  config["DEFAULT_FLOAT"] = dtypes.default_float.name
  config["TRAIN_BEAM"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  config["EVAL_BEAM"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)

  Tensor.manual_seed(seed) # seed for weight initialization

  model = get_mlperf_bert_model()
  if init_ckpt: init_bert_from_checkpoint(model, init_ckpt)

  for _, x in get_state_dict(model).items():
    x.realize().to_(GPUS)
  parameters = get_parameters(model)
  assert 10000 <= (EVAL_BS * max_eval_steps), "EVAL_BS * MAX_EVAL_STEPS must be greater than or equal to 10000 to iterate over the full eval dataset"
  # ** Log hparams **
  for key, value in config.items():
    print(f'HParam: "{key}": {value}')

  # ** Optimizer **
  skip_list = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
  parameters = [x for x in parameters if x not in set(skip_list)]
  optimizer = LAMB(parameters, 1 / warmup_steps, eps=1e-6, wd=decay, adam=False)
  optimizer_skip = LAMB(skip_list, 1 / warmup_steps, eps=1e-6, wd=0.0, adam=False)
  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)

  # ** LR scheduler **
  scheduler = PolynomialDecayWithWarmup(optimizer, max_lr, 0, train_steps, warmup_steps, power=poly_power)
  print(f"training with batch size {BS} for one epoch with {train_steps} steps")

  # ** resume from checkpointing **
  start_step = 0
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler, safe_load(ckpt))
    start_step = scheduler.epoch_counter.numpy().item()
    print(f"resuming from {ckpt} at step {start_step}")

  # ** init wandb **
  WANDB = getenv("WANDB")
  if WANDB:
    import wandb
    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
    wandb.init(config=config, **wandb_args, project="MLPerf-BERT")

  BENCHMARK = getenv("BENCHMARK")

  eval_it = iter(batch_load_val_bert(EVAL_BS))
  train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))

  step_times = []
  # ** train loop **
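  # each iteration launches the jitted train step, then fetches the next batch before reading the loss,
  # so host-side data loading overlaps device execution; loss.numpy() is the call that blocks on the GPU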
  wc_start = time.perf_counter()
  i, train_data = 0, get_data_bert(GPUS, train_it)
  while train_data is not None and i < train_steps and not achieved:
    Tensor.training = True
    BEAM.value = TRAIN_BEAM
    st = time.perf_counter()
    GlobalCounters.reset()
    loss = train_step_bert(model, optimizer_group, scheduler,
      train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
      train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])

    pt = time.perf_counter()

    try:
      next_data = get_data_bert(GPUS, train_it)
    except StopIteration:
      next_data = None

    dt = time.perf_counter()

    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
    loss = loss.numpy().item()

    cl = time.perf_counter()
    if BENCHMARK: step_times.append(cl - st)

    tqdm.write(
      f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer.lr.numpy()[0]:.6f} LR, "
      f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
    if WANDB:
      wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
                 "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                 "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})

    train_data, next_data = next_data, None
    i += 1

    if i == BENCHMARK:
      median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
      estimated_total_minutes = int(median_step_time * train_steps / 60)
      print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
      print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
            f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
      return

    # ** eval loop **
    if i % eval_step_freq == 0 or i == 1:
      train_step_bert.reset() # free the train step memory :(
      eval_lm_losses = []
      eval_clsf_losses = []
      eval_lm_accs = []
      eval_clsf_accs = []
      eval_times = []
      Tensor.training = False
      BEAM.value = EVAL_BEAM

      for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
        eval_data = get_data_bert(GPUS, eval_it)
        GlobalCounters.reset()
        st = time.time()

        eval_result: dict[str, Tensor] = eval_step_bert(model,
          eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"],
          eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])

        lm_loss, clsf_loss = eval_result["masked_lm_loss"].item(), eval_result["next_sentence_loss"].item()
        lm_acc, clsf_acc = eval_result["masked_lm_accuracy"].item(), eval_result["next_sentence_accuracy"].item()

        eval_lm_losses.append(lm_loss)
        eval_clsf_losses.append(clsf_loss)
        eval_lm_accs.append(lm_acc)
        eval_clsf_accs.append(clsf_acc)

        et = time.time()
        eval_times.append(et - st)

        if BENCHMARK and j == BENCHMARK: break

      eval_step_bert.reset()
      avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses)
      avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)
      avg_lm_acc = sum(eval_lm_accs) / len(eval_lm_accs)
      avg_clsf_acc = sum(eval_clsf_accs) / len(eval_clsf_accs)
      avg_fw_time = sum(eval_times) / len(eval_times)
      results = f"eval lm loss: {avg_lm_loss:.2f}, eval clsf loss: {avg_clsf_loss:.2f}, eval lm accuracy: {avg_lm_acc:.6f}, \
        eval clsf accuracy: {avg_clsf_acc:.2f}, avg eval step time: {avg_fw_time:.2f}"
      tqdm.write(results)
      with open(getenv("EVAL_LOG", "./eval_log.txt"), "a") as file: file.write(results + "\n")

      if WANDB:
        wandb.log({"eval/lm_loss": avg_lm_loss, "eval/clsf_loss": avg_clsf_loss, "eval/lm_accuracy": avg_lm_acc, \
                   "eval/clsf_accuracy": avg_clsf_acc, "eval/forward_time": avg_fw_time})

      # save model if achieved target
      if not achieved and avg_lm_acc >= target:
        wc_end = time.perf_counter()
        if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
        fn = f"{ckpt_dir}/bert-large.safe"
        safe_save(get_state_dict(model), fn)
        print(f" *** Model saved to {fn} ***")

        total_seconds = wc_end - wc_start
        hours = int(total_seconds // 3600)
        minutes = int((total_seconds % 3600) // 60)
        seconds = total_seconds % 60
        print(f"Reference convergence point reached after {i * BS} data samples and {hours}h{minutes}m{seconds:.2f}s.")
        achieved = True
        # stop once hitting the target
        break

    if getenv("CKPT") and i % save_ckpt_freq == 0:
      if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
      if WANDB and wandb.run is not None:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
      else:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
      print(f"saving ckpt to {fn}")
      safe_save(get_training_state(model, optimizer_group, scheduler), fn)

      ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
      ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
      while len(ckpt_files) > keep_ckpt_amount:
        last = ckpt_files.pop(0)
        print(f"Removing old ckpt {last}")
        os.remove(os.path.join(ckpt_dir, last))

def train_maskrcnn():
  # TODO: Mask RCNN
  pass
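# MODEL selects which benchmarks to run (e.g. MODEL=resnet or MODEL=bert); any name without a
# matching train_<name> function in globals() is skipped.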
if __name__ == "__main__":
  multiprocessing.set_start_method('spawn')
  with Tensor.train():
    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
      nm = f"train_{m}"
      if nm in globals():
        print(f"training {m}")
        globals()[nm]()