mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
RetinaNet INITMLPERF support (#9950)
* fixes to make fake data work * fix eval beam * fix merge issue
This commit is contained in:
@@ -361,6 +361,7 @@ def train_retinanet():
|
||||
NUM_CLASSES = len(MLPERF_CLASSES)
|
||||
BASE_DIR = getenv("BASE_DIR", BASEDIR)
|
||||
BENCHMARK = getenv("BENCHMARK")
|
||||
INITMLPERF = getenv("INITMLPERF")
|
||||
config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 6))]
|
||||
|
||||
for x in GPUS: Device[x]
|
||||
@@ -379,27 +380,37 @@ def train_retinanet():
|
||||
|
||||
x, y_boxes, y_labels, matches, anchors, cookie = next(it)
|
||||
return x.shard(GPUS, axis=0).realize(), y_boxes.shard(GPUS, axis=0), y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), anchors.shard(GPUS, axis=0), cookie
|
||||
|
||||
def _fake_data_get(bs:int, val:bool=False):
|
||||
x = Tensor.zeros(bs, 800, 800, 3, dtype=dtypes.uint8).contiguous()
|
||||
if val:
|
||||
img_ids, img_sizes = [0] * bs, [(800, 800)] * bs
|
||||
return x.shard(GPUS, axis=0).realize(), img_ids, img_sizes, None
|
||||
|
||||
y_boxes = Tensor.zeros(bs, 120087, 4, dtype=dtypes.float32).contiguous()
|
||||
y_labels = Tensor.zeros(bs, 120087, dtype=dtypes.int64).contiguous()
|
||||
matches = Tensor.ones(bs, 120087, dtype=dtypes.int64).contiguous()
|
||||
anchors = Tensor.zeros(bs, 120087, 4, dtype=dtypes.float64).contiguous()
|
||||
return x.shard(GPUS, axis=0).realize(), y_boxes.shard(GPUS, axis=0), y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), anchors.shard(GPUS, axis=0), None
|
||||
|
||||
@TinyJit
|
||||
def _train_step(model, optim, loss_scaler, x, **kwargs):
|
||||
with Context(BEAM=TRAIN_BEAM):
|
||||
optim.zero_grad()
|
||||
optim.zero_grad()
|
||||
|
||||
losses = model(normalize(x, GPUS), **kwargs)
|
||||
loss = sum([l for l in losses.values()])
|
||||
losses = model(normalize(x, GPUS), **kwargs)
|
||||
loss = sum([l for l in losses.values()])
|
||||
|
||||
(loss * loss_scaler).backward()
|
||||
for t in optim.params: t.grad = t.grad / loss_scaler
|
||||
(loss * loss_scaler).backward()
|
||||
for t in optim.params: t.grad = t.grad / loss_scaler
|
||||
|
||||
optim.step()
|
||||
optim.step()
|
||||
|
||||
return loss.realize(), losses
|
||||
return loss.realize(), losses
|
||||
|
||||
@TinyJit
|
||||
def _eval_step(model, x, **kwargs):
|
||||
with Context(BEAM=EVAL_BEAM):
|
||||
out = model(normalize(x, GPUS), **kwargs)
|
||||
return out.realize()
|
||||
out = model(normalize(x, GPUS), **kwargs)
|
||||
return out.realize()
|
||||
|
||||
# ** hyperparameters **
|
||||
config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
|
||||
@@ -441,13 +452,16 @@ def train_retinanet():
|
||||
optim = Adam(params, lr=lr)
|
||||
|
||||
# ** dataset **
|
||||
train_dataset = COCO(download_dataset(BASE_DIR, "train"))
|
||||
val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
|
||||
coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")
|
||||
if INITMLPERF:
|
||||
config["steps_in_train_epoch"] = steps_in_train_epoch = BS
|
||||
config["steps_in_val_epoch"] = steps_in_val_epoch = EVAL_BS
|
||||
else:
|
||||
train_dataset = COCO(download_dataset(BASE_DIR, "train"))
|
||||
val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
|
||||
coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")
|
||||
|
||||
# ** lr scheduler **
|
||||
config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
|
||||
config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), EVAL_BS) // EVAL_BS)
|
||||
config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
|
||||
config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), EVAL_BS) // EVAL_BS)
|
||||
|
||||
# ** initialize wandb **
|
||||
if (WANDB:=getenv("WANDB")):
|
||||
@@ -458,9 +472,14 @@ def train_retinanet():
|
||||
|
||||
for e in range(start_epoch, EPOCHS):
|
||||
# ** training loop **
|
||||
train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
|
||||
it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
|
||||
i, proc = 0, _data_get(it)
|
||||
BEAM.value = TRAIN_BEAM
|
||||
|
||||
if INITMLPERF:
|
||||
i, proc = 0, _fake_data_get(BS)
|
||||
else:
|
||||
train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
|
||||
it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
|
||||
i, proc = 0, _data_get(it)
|
||||
|
||||
prev_cookies = []
|
||||
st = time.perf_counter()
|
||||
@@ -475,7 +494,10 @@ def train_retinanet():
|
||||
|
||||
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
|
||||
try:
|
||||
next_proc = _data_get(it)
|
||||
if INITMLPERF:
|
||||
next_proc = _fake_data_get(BS)
|
||||
else:
|
||||
next_proc = _data_get(it)
|
||||
except StopIteration:
|
||||
next_proc = None
|
||||
|
||||
@@ -520,15 +542,20 @@ def train_retinanet():
|
||||
|
||||
# ** eval loop **
|
||||
if (e + 1) % eval_freq == 0:
|
||||
BEAM.value = EVAL_BEAM
|
||||
|
||||
if getenv("RESET_STEP", 1): _train_step.reset()
|
||||
|
||||
with Tensor.train(mode=False), Tensor.test():
|
||||
val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
|
||||
it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
|
||||
i, proc = 0, _data_get(it, val=val)
|
||||
if INITMLPERF:
|
||||
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
|
||||
else:
|
||||
val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
|
||||
it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
|
||||
i, proc = 0, _data_get(it, val=val)
|
||||
val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
|
||||
|
||||
eval_times, prev_cookies = [], []
|
||||
val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
|
||||
|
||||
while proc is not None:
|
||||
GlobalCounters.reset()
|
||||
@@ -536,20 +563,25 @@ def train_retinanet():
|
||||
|
||||
out, img_ids, img_sizes, proc = _eval_step(model, (x:=proc[0])).numpy(), proc[1], proc[2], proc[3]
|
||||
out = model.postprocess_detections(out, input_size=x.shape[1:3], orig_image_sizes=img_sizes)
|
||||
coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
|
||||
for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]
|
||||
|
||||
with redirect_stdout(None):
|
||||
coco_val.cocoDt = val_dataset.loadRes(coco_results)
|
||||
coco_val.params.imgIds = img_ids
|
||||
coco_val.evaluate()
|
||||
if not INITMLPERF:
|
||||
coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
|
||||
for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]
|
||||
|
||||
val_img_ids.extend(img_ids)
|
||||
val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
|
||||
with redirect_stdout(None):
|
||||
coco_val.cocoDt = val_dataset.loadRes(coco_results)
|
||||
coco_val.params.imgIds = img_ids
|
||||
coco_val.evaluate()
|
||||
|
||||
val_img_ids.extend(img_ids)
|
||||
val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
|
||||
|
||||
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
|
||||
try:
|
||||
next_proc = _data_get(it, val=val)
|
||||
if INITMLPERF:
|
||||
next_proc = _fake_data_get(EVAL_BS, val=val)
|
||||
else:
|
||||
next_proc = _data_get(it, val=val)
|
||||
except StopIteration:
|
||||
next_proc = None
|
||||
|
||||
@@ -566,22 +598,23 @@ def train_retinanet():
|
||||
if getenv("RESET_STEP", 1): _eval_step.reset()
|
||||
total_fw_time = sum(eval_times) / len(eval_times)
|
||||
|
||||
coco_val.params.imgIds = val_img_ids
|
||||
coco_val._paramsEval.imgIds = val_img_ids
|
||||
coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
|
||||
coco_val.accumulate()
|
||||
coco_val.summarize()
|
||||
if not INITMLPERF:
|
||||
coco_val.params.imgIds = val_img_ids
|
||||
coco_val._paramsEval.imgIds = val_img_ids
|
||||
coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
|
||||
coco_val.accumulate()
|
||||
coco_val.summarize()
|
||||
|
||||
val_metric = coco_val.stats[0]
|
||||
val_metric = coco_val.stats[0]
|
||||
|
||||
tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")
|
||||
tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")
|
||||
|
||||
if WANDB:
|
||||
wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
|
||||
if WANDB:
|
||||
wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
|
||||
|
||||
if val_metric >= target_metric:
|
||||
print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
|
||||
break
|
||||
if val_metric >= target_metric:
|
||||
print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
|
||||
break
|
||||
|
||||
def train_unet3d():
|
||||
"""
|
||||
|
||||
@@ -160,7 +160,7 @@ class RegressionHead:
|
||||
self.bbox_reg = ConvHead(in_channels, num_anchors * 4, kernel_size=3, padding=1)
|
||||
|
||||
if box_coder is None:
|
||||
box_coder = BoxCoder((1.0, 1.0, 1.0, 1.0), apply_to_remove=False)
|
||||
box_coder = BoxCoder((1.0, 1.0, 1.0, 1.0))
|
||||
self.box_coder = box_coder
|
||||
|
||||
def __call__(self, x:Tensor, bboxes:Tensor|None=None, matches:Tensor|None=None, anchors:Tensor|None=None):
|
||||
|
||||
Reference in New Issue
Block a user