estimated resnet training time for BENCHMARK (#3769)

This commit is contained in:
chenyu
2024-03-15 22:36:58 -04:00
committed by GitHub
parent 0870dd5b3b
commit e1c5aa9cce

View File

@@ -84,6 +84,8 @@ def train_resnet():
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
wandb.init(config=config, **wandb_args)
BENCHMARK = getenv("BENCHMARK")
# ** jitted steps **
input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
# mlperf reference resnet does not divide by input_std for some reason
@@ -112,11 +114,12 @@ def train_resnet():
return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), cookie
# ** epoch loop **
step_times = []
for e in range(start_epoch, epochs):
# ** train loop **
Tensor.training = True
it = iter(tqdm(batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e),
total=steps_in_train_epoch, desc=f"epoch {e}", disable=getenv("BENCHMARK")))
total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
i, proc = 0, data_get(it)
st = time.perf_counter()
while proc is not None:
@@ -136,6 +139,8 @@ def train_resnet():
loss, top_1_acc = loss.numpy().item(), top_1_acc.numpy().item() / BS
cl = time.perf_counter()
if BENCHMARK:
step_times.append(cl - st)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
@@ -150,7 +155,11 @@ def train_resnet():
proc, next_proc = next_proc, None # return old cookie
i += 1
if i == getenv("BENCHMARK"): return
if i == BENCHMARK:
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_hours = median_step_time * steps_in_train_epoch * epochs / 60 / 60
print(f"Estimated training time: {estimated_total_hours:.0f}h{(estimated_total_hours - int(estimated_total_hours)) * 60:.0f}m")
return
# ** eval loop **
if (e + 1 - eval_start_epoch) % eval_epochs == 0:
@@ -233,5 +242,3 @@ if __name__ == "__main__":
if nm in globals():
print(f"training {m}")
globals()[nm]()