bert put eval copy and getting lr in jit (#9350)

This commit is contained in:
chenyu
2025-03-04 20:57:03 -05:00
committed by GitHub
parent 7576a1da23
commit ad72269f08

View File

@@ -592,7 +592,8 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
optimizer.step()
scheduler.step()
return loss.realize(), global_norm.realize()
# TODO: no to("CPU") here because it blocks and messes the python time
return loss.realize(), global_norm.realize(), optimizer.optimizers[0].lr.realize()
@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
@@ -603,7 +604,8 @@ def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:T
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
return masked_lm_accuracy.realize(), seq_relationship_accuracy.realize(), masked_lm_loss.realize(), next_sentence_loss.realize()
return masked_lm_accuracy.to("CPU").realize(), seq_relationship_accuracy.realize().to("CPU"), \
masked_lm_loss.to("CPU").realize(), next_sentence_loss.to("CPU").realize()
def train_bert():
# NOTE: pip install tensorflow, wandb required
@@ -771,14 +773,13 @@ def train_bert():
if MLLOGGER:
MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*BS, metadata={"epoch_num": i*BS})
# TODO: put copy into jit
while train_data is not None and i < train_steps and not achieved:
if getenv("TRAIN", 1):
Tensor.training = True
BEAM.value = TRAIN_BEAM
st = time.perf_counter()
GlobalCounters.reset()
loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
@@ -791,18 +792,19 @@ def train_bert():
dt = time.perf_counter()
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
device_str = parameters[0].device if isinstance(parameters[0].device, str) else f"{parameters[0].device[0]} * {len(parameters[0].device)}"
loss = loss.item()
lr = lr.item()
cl = time.perf_counter()
if BENCHMARK: step_times.append(cl - st)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {lr:.6f} LR, "
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
if WANDB:
wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
wandb.log({"lr": lr, "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
"train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})