use validation dataloader inside retinanet eval (#9747)

This commit is contained in:
Francis Lata
2025-04-05 16:46:55 -04:00
committed by GitHub
parent 5f7c79676f
commit 71b8890dd6
2 changed files with 38 additions and 32 deletions

View File

@@ -81,47 +81,42 @@ def eval_unet3d():
def eval_retinanet():
# RetinaNet with ResNeXt50_32X4D
from examples.mlperf.dataloader import batch_load_retinanet
from extra.datasets.openimages import normalize, download_dataset, BASEDIR
from extra.models.resnet import ResNeXt50_32X4D
from extra.models.retinanet import RetinaNet
mdl = RetinaNet(ResNeXt50_32X4D())
mdl.load_from_pretrained()
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
def input_fixup(x):
x = x.permute([0,3,1,2]) / 255.0
x -= input_mean
x /= input_std
return x
from extra.datasets.openimages import download_dataset, iterate, BASEDIR
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from contextlib import redirect_stdout
tlog("imports")
mdl = RetinaNet(ResNeXt50_32X4D())
mdl.load_from_pretrained()
tlog("loaded models")
coco = COCO(download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), 'validation'))
coco_eval = COCOeval(coco, iouType="bbox")
coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
tlog("loaded dataset")
from tinygrad.engine.jit import TinyJit
mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
n, bs = 0, 8
iterator = batch_load_retinanet(coco, True, Path(base_dir), getenv("BS", 8), shuffle=False)
def data_get():
x, img_ids, img_sizes, cookie = next(iterator)
return x.to(Device.DEFAULT).realize(), img_ids, img_sizes, cookie
n = 0
proc = data_get()
tlog("loaded initial data")
st = time.perf_counter()
for x, targets in iterate(coco, base_dir, bs):
dat = Tensor(x.astype(np.float32))
mt = time.perf_counter()
if dat.shape[0] == bs:
outs = mdlrun(dat).numpy()
else:
mdlrun._jit_cache = []
outs = mdl(input_fixup(dat)).numpy()
et = time.perf_counter()
predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
ext = time.perf_counter()
n += len(targets)
print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")
img_ids = [t["image_id"] for t in targets]
coco_results = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box.tolist(), "score": score}
while proc is not None:
GlobalCounters.reset()
proc = (mdl(normalize(proc[0])), proc[1], proc[2], proc[3])
run = time.perf_counter()
# load the next data here
try: next_proc = data_get()
except StopIteration: next_proc = None
nd = time.perf_counter()
predictions, img_ids = mdl.postprocess_detections(proc[0].numpy(), orig_image_sizes=proc[2]), proc[1]
coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())]
with redirect_stdout(None):
coco_eval.cocoDt = coco.loadRes(coco_results)
@@ -129,13 +124,18 @@ def eval_retinanet():
coco_eval.evaluate()
evaluated_imgs.extend(img_ids)
coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)))
st = time.perf_counter()
n += len(proc[0])
et = time.perf_counter()
tlog(f"****** {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
st = et
proc, next_proc = next_proc, None
coco_eval.params.imgIds = evaluated_imgs
coco_eval._paramsEval.imgIds = evaluated_imgs
coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten())
coco_eval.accumulate()
coco_eval.summarize()
tlog("done")
def eval_rnnt():
# RNN-T

View File

@@ -205,6 +205,12 @@ def resize(img:Image, tgt:dict[str, np.ndarray|tuple]|None=None, size:tuple[int,
return img, img_size
def normalize(img:Tensor, device:list[str]|None = None):
mean = Tensor([0.485, 0.456, 0.406], device=device, dtype=dtypes.float32).reshape(1, -1, 1, 1)
std = Tensor([0.229, 0.224, 0.225], device=device, dtype=dtypes.float32).reshape(1, -1, 1, 1)
img = ((img.permute([0, 3, 1, 2]) / 255.0) - mean) / std
return img.cast(dtypes.default_float)
if __name__ == "__main__":
download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "train")
download_dataset(base_dir, "validation")