mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add multi device support to retinanet eval
This commit is contained in:
@@ -81,35 +81,48 @@ def eval_unet3d():
|
||||
|
||||
def eval_retinanet():
|
||||
# RetinaNet with ResNeXt50_32X4D
|
||||
from extra.datasets.openimages import normalize
|
||||
from extra.datasets.openimages import normalize, download_dataset, iterate, BASEDIR
|
||||
from extra.models.resnet import ResNeXt50_32X4D
|
||||
from extra.models.retinanet import RetinaNet
|
||||
mdl = RetinaNet(ResNeXt50_32X4D())
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
from extra.datasets.openimages import download_dataset, iterate, BASEDIR
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
class RetinaNetRunner:
|
||||
def __init__(self, bs, device=None):
|
||||
self.bs, self.device = bs, device
|
||||
self.mdl = RetinaNet(ResNeXt50_32X4D())
|
||||
self.mdlrun = TinyJit(lambda x: self.mdl(normalize(x, device=device)).realize())
|
||||
for x in get_parameters(self.mdl) if device else []: x.to_(device)
|
||||
if (fn:=getenv("RETINANET_MODEL", "")): load_state_dict(self.mdl, safe_load(fn))
|
||||
else: self.mdl.load_from_pretrained()
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
if x.shape[0] == bs: return self.mdlrun(x)
|
||||
else:
|
||||
self.mdlrun._jit_cache = []
|
||||
return self.mdl(normalize(x, device=self.device))
|
||||
|
||||
def postprocess_detections(self, *args, **kwargs):
|
||||
return self.mdl.postprocess_detections(*args, **kwargs)
|
||||
|
||||
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 6))]
|
||||
for x in GPUS: Device[x]
|
||||
|
||||
coco = COCO(download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), 'validation'))
|
||||
coco_eval = COCOeval(coco, iouType="bbox")
|
||||
coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
|
||||
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
mdlrun = TinyJit(lambda x: mdl(normalize(x)).realize())
|
||||
n, bs = 0, getenv("BS", 8)
|
||||
mdl = RetinaNetRunner(bs, device=GPUS)
|
||||
|
||||
n, bs = 0, 8
|
||||
st = time.perf_counter()
|
||||
for x, targets in iterate(coco, base_dir, bs):
|
||||
dat = Tensor(x.astype(np.float32))
|
||||
x = Tensor(x.astype(np.float32), device=GPUS)
|
||||
mt = time.perf_counter()
|
||||
if dat.shape[0] == bs:
|
||||
outs = mdlrun(dat).numpy()
|
||||
else:
|
||||
mdlrun._jit_cache = []
|
||||
outs = mdl(normalize(dat)).numpy()
|
||||
outs = mdl(x).numpy()
|
||||
et = time.perf_counter()
|
||||
predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
|
||||
predictions = mdl.postprocess_detections(outs, input_size=x.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
|
||||
ext = time.perf_counter()
|
||||
n += len(targets)
|
||||
print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")
|
||||
|
||||
Reference in New Issue
Block a user