mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix bert free_intermediates (#9633)
fix when only run eval `TRAIN=0 BERT_SIZE=tiny examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh`
This commit is contained in:
@@ -826,7 +826,7 @@ def train_bert():
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i})
|
||||
if getenv("RESET_STEP", 0): train_step_bert.reset()
|
||||
else: train_step_bert.captured.free_intermediates()
|
||||
elif train_step_bert.captured is not None: train_step_bert.captured.free_intermediates()
|
||||
eval_lm_losses = []
|
||||
eval_clsf_losses = []
|
||||
eval_lm_accs = []
|
||||
@@ -860,7 +860,7 @@ def train_bert():
|
||||
return
|
||||
|
||||
if getenv("RESET_STEP", 0): eval_step_bert.reset()
|
||||
else: eval_step_bert.captured.free_intermediates()
|
||||
elif eval_step_bert.captured is not None: eval_step_bert.captured.free_intermediates()
|
||||
del eval_data
|
||||
avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses)
|
||||
avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)
|
||||
|
||||
Reference in New Issue
Block a user