Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
hotfix bert no shard with only one device (#9243)
`LLVM=1 BERT_SIZE="tiny" DEFAULT_FLOAT=HALF BENCHMARK=5 MODEL="bert" python3 examples/mlperf/model_train.py` runs for me with this. It should not fail with a single-device shard, though.
@@ -574,8 +574,9 @@ def train_rnnt():
 @TinyJit
 def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
                     masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-    t.shard_(GPUS, axis=0)
+  if len(GPUS) > 1:
+    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+      t.shard_(GPUS, axis=0)
   optimizer.zero_grad()

   lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
@@ -596,8 +597,9 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
 @TinyJit
 def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
                    masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-    t.shard_(GPUS, axis=0)
+  if len(GPUS) > 1:
+    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+      t.shard_(GPUS, axis=0)
   lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
   masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
     model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
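
Both hunks above apply the same guard: the input batch tensors are only sharded along axis 0 when more than one device is in use. A minimal sketch of the pattern, assuming a placeholder device tuple (the real run builds GPUS from the benchmark config, not shown here):

from tinygrad import Tensor

GPUS = ("CPU",)  # placeholder single-device tuple; e.g. ("GPU:0", "GPU:1") for a multi-GPU run

batch = Tensor.ones(8, 128)     # stand-in for one of the BERT input tensors
if len(GPUS) > 1:
  batch.shard_(GPUS, axis=0)    # split the batch dimension across the devices
print(batch.device)             # with a single device, the tensor is left untouched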
@@ -696,8 +698,9 @@ def train_bert():
       p = p.assign(Tensor.zeros_like(p).contiguous()).realize()

   parameters = get_parameters(model)
-  for p in parameters:
-    p.to_(GPUS)
+  if len(GPUS) > 1:
+    for p in parameters:
+      p.to_(GPUS)

   # ** Log run config **
   for key, value in config.items(): print(f'HParam: "{key}": {value}')
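
The last hunk guards parameter replication the same way: calling `Tensor.to_` with a tuple of devices replicates a parameter across all of them in place, and that step is now skipped in the single-device case. A hedged sketch, with a toy Linear layer standing in for the BERT model:

from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters

GPUS = ("CPU",)         # placeholder; with a single device, parameters stay where they are
model = Linear(128, 2)  # toy stand-in for the BERT model
if len(GPUS) > 1:
  for p in get_parameters(model):
    p.to_(GPUS)  # replicate each parameter across all devices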