add multi device support to retinanet eval

This commit is contained in:
Francis Lata
2025-01-29 10:00:46 -08:00
parent dcd1941b94
commit 335d11281c

View File

# NOTE(review): this is a rendered git diff with the +/- markers and indentation
# stripped — adjacent lines below are old/new variants of each other, not
# sequential statements. Do not run this text as-is.
@@ -81,35 +81,48 @@ def eval_unet3d():
def eval_retinanet():
# RetinaNet with ResNeXt50_32X4D
# The next two import lines are the pre-change and post-change variants of the
# same statement (the change adds download_dataset, iterate, BASEDIR).
from extra.datasets.openimages import normalize
from extra.datasets.openimages import normalize, download_dataset, iterate, BASEDIR
from extra.models.resnet import ResNeXt50_32X4D
from extra.models.retinanet import RetinaNet
# Pre-change model construction at function scope; presumably replaced by the
# RetinaNetRunner wrapper introduced below — confirm against the applied commit.
mdl = RetinaNet(ResNeXt50_32X4D())
mdl.load_from_pretrained()
# Pre-change duplicate import line, superseded by the combined import above.
from extra.datasets.openimages import download_dataset, iterate, BASEDIR
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from contextlib import redirect_stdout
class RetinaNetRunner:
  """Wraps RetinaNet (ResNeXt50_32X4D backbone) inference with a jitted fast path.

  Full batches of exactly `bs` images go through a TinyJit-compiled forward
  pass; a ragged final batch falls back to an unjitted call, after clearing
  the jit cache that was specialized to the full-batch shape.
  """
  def __init__(self, bs, device=None):
    # device: a single device name or a list of devices (multi-device shard);
    # when given, model parameters are moved onto it.
    self.bs, self.device = bs, device
    self.mdl = RetinaNet(ResNeXt50_32X4D())
    # jitted path: normalize on-device, run the model, force realization
    self.mdlrun = TinyJit(lambda x: self.mdl(normalize(x, device=self.device)).realize())
    for p in (get_parameters(self.mdl) if device else []): p.to_(device)
    # RETINANET_MODEL=<path> overrides the pretrained weights with a local checkpoint
    if (fn:=getenv("RETINANET_MODEL", "")): load_state_dict(self.mdl, safe_load(fn))
    else: self.mdl.load_from_pretrained()
  def __call__(self, x:Tensor) -> Tensor:
    # BUGFIX: compare against self.bs, not the enclosing-scope `bs` — the
    # original read a closure/global variable, breaking the class anywhere
    # that name is absent or differs from the constructor argument.
    if x.shape[0] == self.bs: return self.mdlrun(x)
    # ragged last batch: the jit is shape-specialized, so drop its cache and
    # run the model unjitted for this one call
    self.mdlrun._jit_cache = []
    return self.mdl(normalize(x, device=self.device))
  def postprocess_detections(self, *args, **kwargs):
    # thin pass-through to the underlying model's detection postprocessing
    return self.mdl.postprocess_detections(*args, **kwargs)
# NOTE(review): rendered diff continues — old and new variants of the eval loop
# are interleaved below without +/- markers; pairs are annotated where apparent,
# but which member of each pair is pre- vs post-change cannot be confirmed from
# this rendering alone.
# Multi-device setup: GPUS env var picks the shard count (default 6).
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 6))]
for x in GPUS: Device[x]
coco = COCO(download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), 'validation'))
# NOTE(review): COCOeval's full signature is (cocoGt, cocoDt, iouType); only the
# ground truth is passed here — presumably detections are attached later. Verify.
coco_eval = COCOeval(coco, iouType="bbox")
coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
# Apparent pre-change jit setup, superseded by RetinaNetRunner:
from tinygrad.engine.jit import TinyJit
mdlrun = TinyJit(lambda x: mdl(normalize(x)).realize())
# Two variants of the batch-size line appear (getenv("BS", 8) vs hardcoded 8):
n, bs = 0, getenv("BS", 8)
mdl = RetinaNetRunner(bs, device=GPUS)
n, bs = 0, 8
st = time.perf_counter()
for x, targets in iterate(coco, base_dir, bs):
# Variant pair: host tensor vs tensor sharded across GPUS.
dat = Tensor(x.astype(np.float32))
x = Tensor(x.astype(np.float32), device=GPUS)
mt = time.perf_counter()
# Apparent pre-change inline jit/fallback logic (now inside RetinaNetRunner.__call__):
if dat.shape[0] == bs:
outs = mdlrun(dat).numpy()
else:
mdlrun._jit_cache = []
outs = mdl(normalize(dat)).numpy()
# Apparent post-change single call through the runner:
outs = mdl(x).numpy()
et = time.perf_counter()
# Variant pair: postprocess using the old `dat` vs the new `x` tensor.
predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
predictions = mdl.postprocess_detections(outs, input_size=x.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
ext = time.perf_counter()
n += len(targets)
print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")