add global_batch_size to mlperf bert (#10852)
global_batch_size = grad_acc_steps * batch_size. A no-op change to prepare gradient accumulation for BERT.
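
As a rough sketch of how the derived hyperparameters follow from the new GBS value (the GPU count, dtype, and grad_acc below are illustrative assumptions, not part of this change; with grad_acc fixed at 1 the numbers match the old BS-based formulas):

import math

# hypothetical setup: 8 GPUs with a half-precision default dtype, no gradient accumulation yet
num_gpus, grad_acc = 8, 1
BS = 11 * num_gpus   # per-optimizer-step batch size (11 per GPU for float16/bfloat16)
GBS = BS * grad_acc  # global batch size; equal to BS while grad_acc == 1

max_lr = 0.000175 * math.sqrt(GBS / 96)  # ~1.68e-4 for GBS=88
train_steps = 3600000 // GBS             # 40909 for GBS=88
eval_step_freq = int((math.floor(0.05 * (230.23 * GBS + 3000000) / 25000) * 25000) / GBS)  # 1704 for GBS=88

print(BS, GBS, round(max_lr, 6), train_steps, eval_step_freq)
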
@@ -999,16 +999,20 @@ def train_bert():
     MLLOGGER = None

   # ** hyperparameters **
-  BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
+  BS = config["BS"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
+  grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
+  # TODO: implement grad accumulation + mlperf logging
+  assert grad_acc == 1
+  GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
-  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000175 * math.sqrt(BS/96))
+  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000175 * math.sqrt(GBS/96))
   opt_lamb_beta_1 = config["OPT_LAMB_BETA_1"] = getenv("OPT_LAMB_BETA_1", 0.9)
   opt_lamb_beta_2 = config["OPT_LAMB_BETA_2"] = getenv("OPT_LAMB_BETA_2", 0.999)

-  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3600000 // BS)
+  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3600000 // GBS)
   warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
   max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000
-  eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down
+  eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * GBS + 3000000) / 25000) * 25000) / GBS)) # Round down
   save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
   keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
   save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
@@ -1066,7 +1070,7 @@ def train_bert():
   scheduler_wd = PolynomialDecayWithWarmup(optimizer_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
   scheduler_no_wd = PolynomialDecayWithWarmup(optimizer_no_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
   scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)
-  print(f"training with batch size {BS} for one epoch with {train_steps} steps")
+  print(f"training with global batch size {GBS} for one epoch with {train_steps} steps")

   # log mlperf hparams
   if MLLOGGER:
@@ -1119,7 +1123,7 @@ def train_bert():

   if RUNMLPERF:
     if MLLOGGER:
-      MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*BS, metadata={"epoch_num": i*BS})
+      MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*GBS, metadata={"epoch_num": i*GBS})

   while train_data is not None and i < train_steps and not achieved:
     if getenv("TRAIN", 1):
@@ -1156,7 +1160,7 @@ def train_bert():
       if WANDB:
         wandb.log({"lr": lr, "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
                    "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
-                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})
+                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*GBS})

       train_data, next_data = next_data, None
       i += 1
@@ -1171,7 +1175,7 @@ def train_bert():
       # ** eval loop **
       if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:
         if MLLOGGER and RUNMLPERF:
-          MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i})
+          MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*GBS, "step_num": i})
         if getenv("RESET_STEP"): train_step_bert.reset()
         elif getenv("FREE_INTERMEDIATE", 1) and train_step_bert.captured is not None: train_step_bert.captured.free_intermediates()
         eval_lm_losses = []
@@ -1221,11 +1225,11 @@ def train_bert():

         if WANDB:
           wandb.log({"eval/lm_loss": avg_lm_loss, "eval/clsf_loss": avg_clsf_loss, "eval/lm_accuracy": avg_lm_acc, \
-                     "eval/clsf_accuracy": avg_clsf_acc, "eval/forward_time": avg_fw_time, "epoch": (i+1)*BS})
+                     "eval/clsf_accuracy": avg_clsf_acc, "eval/forward_time": avg_fw_time, "epoch": (i+1)*GBS})

         if MLLOGGER and RUNMLPERF:
-          MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=i*BS, metadata={"epoch_count": i*BS, "step_num": i, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"]})
-          MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=avg_lm_acc, metadata={"epoch_num": i*BS, "masked_lm_accuracy": avg_lm_acc})
+          MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=i*GBS, metadata={"epoch_count": i*GBS, "step_num": i, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"]})
+          MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=avg_lm_acc, metadata={"epoch_num": i*GBS, "masked_lm_accuracy": avg_lm_acc})

         # save model if achieved target
         if not achieved and avg_lm_acc >= target:
@@ -1240,10 +1244,10 @@ def train_bert():
           hours = int(total_seconds // 3600)
           minutes = int((total_seconds % 3600) // 60)
           seconds = total_seconds % 60
-          print(f"Reference Convergence point reached after {i * BS} datasamples and {hours}h{minutes}m{seconds:.2f}s.")
+          print(f"Reference Convergence point reached after {i * GBS} datasamples and {hours}h{minutes}m{seconds:.2f}s.")
           achieved = True
           if MLLOGGER and RUNMLPERF:
-            MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=i*BS, metadata={"epoch_num": i*BS})
+            MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=i*GBS, metadata={"epoch_num": i*GBS})
             MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
           # stop once hitting the target
           break
@@ -1271,13 +1275,9 @@ def train_bert():
             os.remove(os.path.join(ckpt_dir, last))
         if MLLOGGER and RUNMLPERF:
           MLLOGGER.end(key="checkpoint_stop", value=None, metadata={"step_num": i})
-          MLLOGGER.start(key=mllog_constants.BLOCK_START, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "epoch_count": 1, "samples_count": i * BS, "step_num": i, "first_step_num": i+1})
+          MLLOGGER.start(key=mllog_constants.BLOCK_START, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "epoch_count": 1, "samples_count": i * GBS, "step_num": i, "first_step_num": i+1})
         previous_step = i

 def train_maskrcnn():
   # TODO: Mask RCNN
   pass

 if __name__ == "__main__":
   multiprocessing.set_start_method('spawn')
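
The TODO in the first hunk (and the assert grad_acc == 1) marks that gradient accumulation is not implemented yet; GBS only exists so the schedules and logging are already expressed in global-batch terms. A minimal sketch of what consuming grad_acc could later look like, using a toy linear model and generic tinygrad calls rather than the real BERT train step (every name and value below is a hypothetical placeholder):

from tinygrad import Tensor, nn

BS, grad_acc = 4, 2
GBS = BS * grad_acc  # samples seen per optimizer update

model = nn.Linear(16, 1)
opt = nn.optim.SGD(nn.state.get_parameters(model), lr=1e-3)

with Tensor.train():
  for step in range(3):           # one iteration = one global step of GBS samples
    opt.zero_grad()
    for _ in range(grad_acc):     # micro-batches of size BS
      x, y = Tensor.randn(BS, 16), Tensor.randn(BS, 1)
      loss = ((model(x) - y) ** 2).mean() / grad_acc  # scale so grads average over GBS
      loss.backward()             # backward adds into existing .grad, accumulating across micro-batches
    opt.step()                    # a single optimizer update per GBS samples
    print(f"step {step}: loss {loss.item():.4f}")
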