mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
estimated resnet training time for BENCHMARK (#3769)
This commit is contained in:
@@ -84,6 +84,8 @@ def train_resnet():
|
||||
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
|
||||
wandb.init(config=config, **wandb_args)
|
||||
|
||||
BENCHMARK = getenv("BENCHMARK")
|
||||
|
||||
# ** jitted steps **
|
||||
input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
|
||||
# mlperf reference resnet does not divide by input_std for some reason
|
||||
@@ -112,11 +114,12 @@ def train_resnet():
|
||||
return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), cookie
|
||||
|
||||
# ** epoch loop **
|
||||
step_times = []
|
||||
for e in range(start_epoch, epochs):
|
||||
# ** train loop **
|
||||
Tensor.training = True
|
||||
it = iter(tqdm(batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e),
|
||||
total=steps_in_train_epoch, desc=f"epoch {e}", disable=getenv("BENCHMARK")))
|
||||
total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
|
||||
i, proc = 0, data_get(it)
|
||||
st = time.perf_counter()
|
||||
while proc is not None:
|
||||
@@ -136,6 +139,8 @@ def train_resnet():
|
||||
loss, top_1_acc = loss.numpy().item(), top_1_acc.numpy().item() / BS
|
||||
|
||||
cl = time.perf_counter()
|
||||
if BENCHMARK:
|
||||
step_times.append(cl - st)
|
||||
|
||||
tqdm.write(
|
||||
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
|
||||
@@ -150,7 +155,11 @@ def train_resnet():
|
||||
proc, next_proc = next_proc, None # return old cookie
|
||||
i += 1
|
||||
|
||||
if i == getenv("BENCHMARK"): return
|
||||
if i == BENCHMARK:
|
||||
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
||||
estimated_total_hours = median_step_time * steps_in_train_epoch * epochs / 60 / 60
|
||||
print(f"Estimated training time: {estimated_total_hours:.0f}h{(estimated_total_hours - int(estimated_total_hours)) * 60:.0f}m")
|
||||
return
|
||||
|
||||
# ** eval loop **
|
||||
if (e + 1 - eval_start_epoch) % eval_epochs == 0:
|
||||
@@ -233,5 +242,3 @@ if __name__ == "__main__":
|
||||
if nm in globals():
|
||||
print(f"training {m}")
|
||||
globals()[nm]()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user