mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix TRAIN_BEAM and Tensor.training for mlperf bert (#4525)
also hard coded bert model config instead of looking up a file
This commit is contained in:
@@ -33,7 +33,7 @@ if __name__ == "__main__":
|
||||
|
||||
Tensor.training = False
|
||||
|
||||
model = get_mlperf_bert_model(os.path.join(BASEDIR, "bert_config.json"))
|
||||
model = get_mlperf_bert_model()
|
||||
init_bert_from_checkpoint(model, INIT_CKPT_DIR) # Test the actual loading of the checkpoint
|
||||
|
||||
for _, x in get_state_dict(model).items():
|
||||
|
||||
Reference in New Issue
Block a user