RetinaNet INITMLPERF support (#9950)

* fixes to make fake data work

* fix eval beam

* fix merge issue
Author: Francis Lata
Date: 2025-04-21 10:32:05 -04:00
Committed by: GitHub
Parent commit: 014f870733
Commit: d7e247f329
2 changed files with 80 additions and 47 deletions

View File

@@ -361,6 +361,7 @@ def train_retinanet():
NUM_CLASSES = len(MLPERF_CLASSES)
BASE_DIR = getenv("BASE_DIR", BASEDIR)
BENCHMARK = getenv("BENCHMARK")
INITMLPERF = getenv("INITMLPERF")
config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 6))]
for x in GPUS: Device[x]
@@ -379,27 +380,37 @@ def train_retinanet():
x, y_boxes, y_labels, matches, anchors, cookie = next(it)
return x.shard(GPUS, axis=0).realize(), y_boxes.shard(GPUS, axis=0), y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), anchors.shard(GPUS, axis=0), cookie
def _fake_data_get(bs:int, val:bool=False):
x = Tensor.zeros(bs, 800, 800, 3, dtype=dtypes.uint8).contiguous()
if val:
img_ids, img_sizes = [0] * bs, [(800, 800)] * bs
return x.shard(GPUS, axis=0).realize(), img_ids, img_sizes, None
y_boxes = Tensor.zeros(bs, 120087, 4, dtype=dtypes.float32).contiguous()
y_labels = Tensor.zeros(bs, 120087, dtype=dtypes.int64).contiguous()
matches = Tensor.ones(bs, 120087, dtype=dtypes.int64).contiguous()
anchors = Tensor.zeros(bs, 120087, 4, dtype=dtypes.float64).contiguous()
return x.shard(GPUS, axis=0).realize(), y_boxes.shard(GPUS, axis=0), y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), anchors.shard(GPUS, axis=0), None
@TinyJit
def _train_step(model, optim, loss_scaler, x, **kwargs):
with Context(BEAM=TRAIN_BEAM):
optim.zero_grad()
optim.zero_grad()
losses = model(normalize(x, GPUS), **kwargs)
loss = sum([l for l in losses.values()])
losses = model(normalize(x, GPUS), **kwargs)
loss = sum([l for l in losses.values()])
(loss * loss_scaler).backward()
for t in optim.params: t.grad = t.grad / loss_scaler
(loss * loss_scaler).backward()
for t in optim.params: t.grad = t.grad / loss_scaler
optim.step()
optim.step()
return loss.realize(), losses
return loss.realize(), losses
@TinyJit
def _eval_step(model, x, **kwargs):
with Context(BEAM=EVAL_BEAM):
out = model(normalize(x, GPUS), **kwargs)
return out.realize()
out = model(normalize(x, GPUS), **kwargs)
return out.realize()
# ** hyperparameters **
config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
@@ -441,13 +452,16 @@ def train_retinanet():
optim = Adam(params, lr=lr)
# ** dataset **
train_dataset = COCO(download_dataset(BASE_DIR, "train"))
val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")
if INITMLPERF:
config["steps_in_train_epoch"] = steps_in_train_epoch = BS
config["steps_in_val_epoch"] = steps_in_val_epoch = EVAL_BS
else:
train_dataset = COCO(download_dataset(BASE_DIR, "train"))
val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")
# ** lr scheduler **
config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), EVAL_BS) // EVAL_BS)
config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), EVAL_BS) // EVAL_BS)
# ** initialize wandb **
if (WANDB:=getenv("WANDB")):
@@ -458,9 +472,14 @@ def train_retinanet():
for e in range(start_epoch, EPOCHS):
# ** training loop **
train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
i, proc = 0, _data_get(it)
BEAM.value = TRAIN_BEAM
if INITMLPERF:
i, proc = 0, _fake_data_get(BS)
else:
train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
i, proc = 0, _data_get(it)
prev_cookies = []
st = time.perf_counter()
@@ -475,7 +494,10 @@ def train_retinanet():
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
try:
next_proc = _data_get(it)
if INITMLPERF:
next_proc = _fake_data_get(BS)
else:
next_proc = _data_get(it)
except StopIteration:
next_proc = None
@@ -520,15 +542,20 @@ def train_retinanet():
# ** eval loop **
if (e + 1) % eval_freq == 0:
BEAM.value = EVAL_BEAM
if getenv("RESET_STEP", 1): _train_step.reset()
with Tensor.train(mode=False), Tensor.test():
val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
i, proc = 0, _data_get(it, val=val)
if INITMLPERF:
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
else:
val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
i, proc = 0, _data_get(it, val=val)
val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
eval_times, prev_cookies = [], []
val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
while proc is not None:
GlobalCounters.reset()
@@ -536,20 +563,25 @@ def train_retinanet():
out, img_ids, img_sizes, proc = _eval_step(model, (x:=proc[0])).numpy(), proc[1], proc[2], proc[3]
out = model.postprocess_detections(out, input_size=x.shape[1:3], orig_image_sizes=img_sizes)
coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]
with redirect_stdout(None):
coco_val.cocoDt = val_dataset.loadRes(coco_results)
coco_val.params.imgIds = img_ids
coco_val.evaluate()
if not INITMLPERF:
coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]
val_img_ids.extend(img_ids)
val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
with redirect_stdout(None):
coco_val.cocoDt = val_dataset.loadRes(coco_results)
coco_val.params.imgIds = img_ids
coco_val.evaluate()
val_img_ids.extend(img_ids)
val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
try:
next_proc = _data_get(it, val=val)
if INITMLPERF:
next_proc = _fake_data_get(EVAL_BS, val=val)
else:
next_proc = _data_get(it, val=val)
except StopIteration:
next_proc = None
@@ -566,22 +598,23 @@ def train_retinanet():
if getenv("RESET_STEP", 1): _eval_step.reset()
total_fw_time = sum(eval_times) / len(eval_times)
coco_val.params.imgIds = val_img_ids
coco_val._paramsEval.imgIds = val_img_ids
coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
coco_val.accumulate()
coco_val.summarize()
if not INITMLPERF:
coco_val.params.imgIds = val_img_ids
coco_val._paramsEval.imgIds = val_img_ids
coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
coco_val.accumulate()
coco_val.summarize()
val_metric = coco_val.stats[0]
val_metric = coco_val.stats[0]
tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")
tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")
if WANDB:
wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
if WANDB:
wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
if val_metric >= target_metric:
print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
break
if val_metric >= target_metric:
print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
break
def train_unet3d():
"""

View File

@@ -160,7 +160,7 @@ class RegressionHead:
self.bbox_reg = ConvHead(in_channels, num_anchors * 4, kernel_size=3, padding=1)
if box_coder is None:
box_coder = BoxCoder((1.0, 1.0, 1.0, 1.0), apply_to_remove=False)
box_coder = BoxCoder((1.0, 1.0, 1.0, 1.0))
self.box_coder = box_coder
def __call__(self, x:Tensor, bboxes:Tensor|None=None, matches:Tensor|None=None, anchors:Tensor|None=None):