mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
resnet print epoch ops and mem in benchmark (#4244)
* resnet print epoch ops and mem in benchmark also added a flag to optionally disable reset jitted steps * real per epoch stats
This commit is contained in:
@@ -181,13 +181,15 @@ def train_resnet():
|
||||
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
||||
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
|
||||
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
|
||||
print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
|
||||
f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
|
||||
# if we are doing beam search, run the first eval too
|
||||
if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
|
||||
return
|
||||
|
||||
# ** eval loop **
|
||||
if (e + 1 - eval_start_epoch) % eval_epochs == 0 and steps_in_val_epoch > 0:
|
||||
train_step.reset() # free the train step memory :(
|
||||
if getenv("RESET_STEP", 1): train_step.reset() # free the train step memory :(
|
||||
eval_loss = []
|
||||
eval_times = []
|
||||
eval_top_1_acc = []
|
||||
@@ -214,7 +216,7 @@ def train_resnet():
|
||||
et = time.time()
|
||||
eval_times.append(et - st)
|
||||
|
||||
eval_step.reset()
|
||||
if getenv("RESET_STEP", 1): eval_step.reset()
|
||||
total_loss = sum(eval_loss) / len(eval_loss)
|
||||
total_top_1 = sum(eval_top_1_acc) / len(eval_top_1_acc)
|
||||
total_fw_time = sum(eval_times) / len(eval_times)
|
||||
|
||||
Reference in New Issue
Block a user