hotfix bert no shard with only one device (#9243)

`LLVM=1 BERT_SIZE="tiny" DEFAULT_FLOAT=HALF BENCHMARK=5 MODEL="bert" python3 examples/mlperf/model_train.py` runs for me with this. it should not failed with single device shard though
This commit is contained in:
chenyu
2025-02-25 09:05:11 -05:00
committed by GitHub
parent bba9c22f53
commit 6610ad58ab

View File

@@ -574,8 +574,9 @@ def train_rnnt():
@TinyJit
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
t.shard_(GPUS, axis=0)
if len(GPUS) > 1:
for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
t.shard_(GPUS, axis=0)
optimizer.zero_grad()
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
@@ -596,8 +597,9 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
t.shard_(GPUS, axis=0)
if len(GPUS) > 1:
for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
t.shard_(GPUS, axis=0)
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
@@ -696,8 +698,9 @@ def train_bert():
p = p.assign(Tensor.zeros_like(p).contiguous()).realize()
parameters = get_parameters(model)
for p in parameters:
p.to_(GPUS)
if len(GPUS) > 1:
for p in parameters:
p.to_(GPUS)
# ** Log run config **
for key, value in config.items(): print(f'HParam: "{key}": {value}')