mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add support for BENCHMARK
This commit is contained in:
@@ -359,6 +359,7 @@ def train_retinanet():
|
||||
|
||||
NUM_CLASSES = len(MLPERF_CLASSES)
|
||||
BASE_DIR = getenv("BASE_DIR", BASEDIR)
|
||||
BENCHMARK = getenv("BENCHMARK")
|
||||
config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
|
||||
|
||||
for x in GPUS: Device[x]
|
||||
@@ -438,11 +439,12 @@ def train_retinanet():
|
||||
val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
|
||||
|
||||
config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), bs) // bs
|
||||
step_times = []
|
||||
|
||||
# ** training loop **
|
||||
for e in range(1, num_epochs + 1):
|
||||
train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=bs, seed=seed)
|
||||
it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}"))
|
||||
it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
|
||||
i, proc = 0, _data_get(it)
|
||||
|
||||
# if e < LR_WARMUP_EPOCHS:
|
||||
@@ -474,6 +476,7 @@ def train_retinanet():
|
||||
loss = loss.item()
|
||||
|
||||
cl = time.perf_counter()
|
||||
if BENCHMARK: step_times.append(cl - st)
|
||||
|
||||
tqdm.write(
|
||||
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
|
||||
@@ -491,6 +494,15 @@ def train_retinanet():
|
||||
proc, next_proc = next_proc, None # return old cookie
|
||||
i += 1
|
||||
|
||||
if i == BENCHMARK:
|
||||
assert not math.isnan(loss)
|
||||
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
||||
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * e / 60)
|
||||
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
|
||||
print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
|
||||
f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
|
||||
return
|
||||
|
||||
def train_unet3d():
|
||||
"""
|
||||
Trains the UNet3D model.
|
||||
|
||||
Reference in New Issue
Block a user