@@ -918,32 +918,6 @@ def train_rnnt():
   # TODO: RNN-T
   pass

-@TinyJit
-def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
-                    masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
-    else: t.to_(GPUS[0])
-  optimizer.zero_grad()
-
-  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
-  loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
-  (loss * loss_scaler).backward()
-
-  global_norm = Tensor(0.0, dtype=dtypes.float32, device=optimizer[0].device)
-  for p in optimizer.params:
-    p.grad = p.grad / loss_scaler
-    global_norm += p.grad.float().square().sum()
-  global_norm = global_norm.sqrt().contiguous()
-  for p in optimizer.params:
-    p.grad = (global_norm > 1.0).where((p.grad/global_norm).cast(p.grad.dtype), p.grad)
-
-  optimizer.step()
-  scheduler.step()
-  # TODO: no to("CPU") here because it blocks and messes the python time
-  Tensor.realize(loss, global_norm, optimizer.optimizers[0].lr)
-  return loss, global_norm, optimizer.optimizers[0].lr
-
 @TinyJit
 def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
                    masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
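For context, the hunk that follows rebuilds train_step_bert as a closure inside train_bert(): only the Tensor batches cross the TinyJit boundary as arguments, while model, optimizer_group, scheduler_group, loss_scaler, and GPUS are captured from the enclosing scope. A minimal sketch of that pattern (the make_jit_step helper and the toy loss are illustrative, not part of this diff):

from tinygrad import Tensor, TinyJit

def make_jit_step(model, optimizer):
  @TinyJit
  def step(x:Tensor, y:Tensor) -> Tensor:  # only Tensors are passed in
    # run with Tensor.training = True, as the surrounding training loop does
    optimizer.zero_grad()                  # model/optimizer are captured, not passed
    loss = model(x).sparse_categorical_crossentropy(y)
    loss.backward()
    optimizer.step()
    return loss.realize()
  return step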
@@ -1130,6 +1104,32 @@ def train_bert():
   if MLLOGGER:
     MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*GBS, metadata={"epoch_num": i*GBS})

+  @TinyJit
+  def train_step_bert(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
+                      masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
+    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+      if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+      else: t.to_(GPUS[0])
+    optimizer_group.zero_grad()
+
+    lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
+    loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
+    (loss * loss_scaler).backward()
+
+    global_norm = Tensor(0.0, dtype=dtypes.float32, device=optimizer_group[0].device)
+    for p in optimizer_group.params:
+      p.grad = p.grad / loss_scaler
+      global_norm += p.grad.float().square().sum()
+    global_norm = global_norm.sqrt().contiguous()
+    for p in optimizer_group.params:
+      p.grad = (global_norm > 1.0).where((p.grad/global_norm).cast(p.grad.dtype), p.grad)
+
+    optimizer_group.step()
+    scheduler_group.step()
+    # TODO: no to("CPU") here because it blocks and messes the python time
+    Tensor.realize(loss, global_norm, optimizer_group.optimizers[0].lr)
+    return loss, global_norm, optimizer_group.optimizers[0].lr
+
   while train_data is not None and i < train_steps and not achieved:
     if getenv("TRAIN", 1):
       Tensor.training = True
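The gradient handling in the step above is: undo the loss scaling, compute one global L2 norm across all parameter gradients, then rescale the gradients only when that norm exceeds 1.0. A standalone sketch of the same arithmetic (numpy, with an illustrative default loss_scaler):

import numpy as np

def unscale_and_clip(grads, loss_scaler=2.0**15, max_norm=1.0):
  grads = [g / loss_scaler for g in grads]                        # undo loss scaling
  global_norm = np.sqrt(sum(float((g * g).sum()) for g in grads)) # L2 norm over all grads
  if global_norm > max_norm:
    grads = [g * (max_norm / global_norm) for g in grads]         # same as the where(global_norm > 1.0, ...) above
  return grads, global_norm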
@@ -1137,9 +1137,9 @@ def train_bert():
       st = time.perf_counter()
       GlobalCounters.reset()
       with WallTimeEvent(BenchEvent.STEP):
-        loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
+        loss, global_norm, lr = train_step_bert(
           train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
-          train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
+          train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])

       pt = time.perf_counter()
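The call site no longer passes GPUS because the closure captures it; per-tensor placement still follows the shard-or-move pattern used inside the step. A small sketch of that placement, assuming two hypothetical devices named GPU:0 and GPU:1:

from tinygrad import Tensor

GPUS = [f"GPU:{i}" for i in range(2)]      # assumed device names
x = Tensor.randn(8, 512)                   # a batch on the default device
if len(GPUS) > 1: x.shard_(GPUS, axis=0)   # data parallel: split the batch dim across devices in place
else: x.to_(GPUS[0])                       # single device: just move it in place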