more bert updates (#13597)

prep for splitting the jit
also lower BS from 96 to 72
Author: chenyu (committed by GitHub)
Date: 2025-12-06 08:32:43 -05:00
Parent: 93f1baca77
Commit: 4562f217e1

4 changed files with 31 additions and 31 deletions
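
The "prep split jit" change below moves `train_step_bert` from a module-level function that takes `model`, `optimizer`, `scheduler`, `loss_scaler`, and `GPUS` as arguments to a closure defined inside `train_bert`, so the jitted function's only inputs are the per-step batch tensors. A minimal sketch of that closure pattern, assuming a hypothetical toy model and optimizer (none of these names come from the diff):

```python
from tinygrad import Tensor, TinyJit, nn

model = nn.Linear(8, 8)
opt = nn.optim.SGD(nn.state.get_parameters(model), lr=1e-3)

@TinyJit
def step(x:Tensor, y:Tensor) -> Tensor:
  # model and opt are captured by the closure, so the JIT only
  # sees the per-step batch tensors as inputs
  opt.zero_grad()
  loss = (model(x) - y).square().mean()
  loss.backward()
  opt.step()
  return loss.realize()

with Tensor.train():
  for _ in range(4): loss = step(Tensor.randn(4, 8), Tensor.randn(4, 8))
```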


@@ -918,32 +918,6 @@ def train_rnnt():
   # TODO: RNN-T
   pass
-@TinyJit
-def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
-                    masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
-    else: t.to_(GPUS[0])
-  optimizer.zero_grad()
-  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
-  loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
-  (loss * loss_scaler).backward()
-  global_norm = Tensor(0.0, dtype=dtypes.float32, device=optimizer[0].device)
-  for p in optimizer.params:
-    p.grad = p.grad / loss_scaler
-    global_norm += p.grad.float().square().sum()
-  global_norm = global_norm.sqrt().contiguous()
-  for p in optimizer.params:
-    p.grad = (global_norm > 1.0).where((p.grad/global_norm).cast(p.grad.dtype), p.grad)
-  optimizer.step()
-  scheduler.step()
-  # TODO: no to("CPU") here because it blocks and messes the python time
-  Tensor.realize(loss, global_norm, optimizer.optimizers[0].lr)
-  return loss, global_norm, optimizer.optimizers[0].lr
-
-@TinyJit
-def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
-                   masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
@@ -1130,6 +1104,32 @@ def train_bert():
   if MLLOGGER:
     MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*GBS, metadata={"epoch_num": i*GBS})
+
+  @TinyJit
+  def train_step_bert(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
+                      masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
+    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+      if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+      else: t.to_(GPUS[0])
+    optimizer_group.zero_grad()
+    lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
+    loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
+    (loss * loss_scaler).backward()
+    global_norm = Tensor(0.0, dtype=dtypes.float32, device=optimizer_group[0].device)
+    for p in optimizer_group.params:
+      p.grad = p.grad / loss_scaler
+      global_norm += p.grad.float().square().sum()
+    global_norm = global_norm.sqrt().contiguous()
+    for p in optimizer_group.params:
+      p.grad = (global_norm > 1.0).where((p.grad/global_norm).cast(p.grad.dtype), p.grad)
+    optimizer_group.step()
+    scheduler_group.step()
+    # TODO: no to("CPU") here because it blocks and messes the python time
+    Tensor.realize(loss, global_norm, optimizer_group.optimizers[0].lr)
+    return loss, global_norm, optimizer_group.optimizers[0].lr
+
   while train_data is not None and i < train_steps and not achieved:
     if getenv("TRAIN", 1):
       Tensor.training = True
@@ -1137,9 +1137,9 @@ def train_bert():
       st = time.perf_counter()
       GlobalCounters.reset()
       with WallTimeEvent(BenchEvent.STEP):
-        loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
+        loss, global_norm, lr = train_step_bert(
           train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
-          train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
+          train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])
       pt = time.perf_counter()
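
In both the old and new step, the gradient handling is identical: every grad is unscaled by `loss_scaler`, the global L2 norm is accumulated in float32 (so HALF grads don't overflow), and grads are rescaled only when that norm exceeds 1.0. A standalone sketch of this clip-by-global-norm logic; `clip_grads` is a hypothetical helper, not a function in the repo:

```python
from tinygrad import Tensor, dtypes

def clip_grads(grads:list[Tensor], loss_scaler:float, max_norm:float=1.0) -> tuple[list[Tensor], Tensor]:
  # undo the loss scaling before measuring the norm
  grads = [g / loss_scaler for g in grads]
  # accumulate the squared L2 norm in float32 for numerical stability
  global_norm = Tensor(0.0, dtype=dtypes.float32)
  for g in grads: global_norm = global_norm + g.float().square().sum()
  global_norm = global_norm.sqrt()
  # rescale only when the norm exceeds the threshold, keeping each grad's dtype
  clipped = [(global_norm > max_norm).where((g / global_norm).cast(g.dtype), g) for g in grads]
  return clipped, global_norm
```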


@@ -2,7 +2,7 @@
 export PYTHONPATH="." NV=1
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=72 EVAL_BS=72
 export IGNORE_OOB=1
 export REWRITE_STACK_LIMIT=500000


@@ -2,7 +2,7 @@
 export PYTHONPATH="." NV=1
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=72 EVAL_BS=72
 export IGNORE_OOB=1
 export REWRITE_STACK_LIMIT=500000


@@ -5,7 +5,7 @@ set -o pipefail # Make pipeline fail if any command fails
 export PYTHONPATH="." NV=1
 export MODEL="bert"
 export SUBMISSION_PLATFORM="tinybox_green"
-export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=72 EVAL_BS=72
 export IGNORE_OOB=1
 export REWRITE_STACK_LIMIT=500000
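
All three run scripts keep GPUS=6 while dropping BS and EVAL_BS from 96 to 72, so a batch sharded on axis 0 goes from 96/6 = 16 to 72/6 = 12 samples per device. A quick sanity check of that sharding, assuming tinygrad multi-device tensors (the NV device names and sequence length are illustrative):

```python
from tinygrad import Tensor

GPUS = tuple(f"NV:{i}" for i in range(6))
x = Tensor.empty(72, 512)   # (BS, max_seq_len)
x.shard_(GPUS, axis=0)      # each of the 6 devices holds a 72//6 = 12 row slice
print(x.shape, x.device)    # (72, 512) (NV:0, NV:1, ..., NV:5)
```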