fix TRAIN_BEAM and Tensor.training for mlperf bert (#4525)

also hard coded bert model config instead of looking up a file
This commit is contained in:
chenyu
2024-05-11 00:18:36 -04:00
committed by GitHub
parent 7fab8c9e17
commit b00b6b16f0
3 changed files with 24 additions and 8 deletions

View File

@@ -33,7 +33,7 @@ if __name__ == "__main__":
Tensor.training = False
model = get_mlperf_bert_model(os.path.join(BASEDIR, "bert_config.json"))
model = get_mlperf_bert_model()
init_bert_from_checkpoint(model, INIT_CKPT_DIR) # Test the actual loading of the checkpoint
for _, x in get_state_dict(model).items():