Files
tinygrad/examples/mlperf/model_spec.py
Kunwar Raj Singh 5d3310ce56 MaskRCNN Inference (#884)
* MaskRCNN weights loading

* backbone maybe works

* backbone works, but resnet body atol 1e-3

* RPN call, but very wrong output

* fixed topk

* RPN maybe works, not sure about nms

* Fix cursed modules

* add back editorconfig

* Full call, wrong output

* Full call works

* fix mask

* use NMS from retinanet

* Removing extra funcs

* refactor

* readable

* Add example to run model

* remove filter

* Fix split, batched inference is worse

* Fix image sizes

* Matching reference

* merge master

* add filter on top detections

* cuda backend fixed

* add model eval and spec

* convert images to rgb

* fix eval

* simplify examples code

* remove extra code

* meshgrid using tinygrad

* removing numpy

* roi align, floor, ceil

* remove numpy from level_mapper

* remove numpy from pooler

* Revert "Merge branch 'master' of github.com:kunwar31/tinygrad into mrcnn-inference"

This reverts commit 4b95a3cb49, reversing
changes made to 98f2b1fa2e.

* roi align gather

* fix master merge

* revert to old floor, ceil as ints present in domain

* use log2 op

* fix indexes

* weird bug with ints and gpu

* weird bug with ints and gpu

* refactors, add env var for gather

* floor with contiguous, where

* refactor topk, sort

* remove staticmethod

* refactor stride

* remove log2 mlop

* realize -> contiguous

* refactor forward

* remove num_classes, stride_in_1x1 from state

* refactor forward

* refactoring

* flake8

* removing numpy in anchor gen, use numpy for gather, nonzero, optimize topk

* keep using tinygrad for smaller gathers

* fix empty tensors

* comms

* move from tensor.py

* resnet test passing

* add coco dataset back

* fix spaces

* add test for log2

* no need to create Tensors

* no need to create Tensors

---------

Co-authored-by: Kunwar Raj Singh <kunwar31@pop-os.localdomain>
2023-06-25 15:37:51 -07:00

71 lines
1.9 KiB
Python

# load each model here, quick benchmark
from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters, getenv
import numpy as np
def _realize(x):
  """Recursively replace every Tensor in `x` (bare or nested in list/tuple/dict) with its numpy array."""
  if isinstance(x, Tensor): return x.numpy()
  if isinstance(x, list): return [_realize(v) for v in x]
  if isinstance(x, tuple): return tuple(_realize(v) for v in x)
  if isinstance(x, dict): return {k: _realize(v) for k, v in x.items()}
  return x

def test_model(model, *inputs):
  """Run one forward pass of `model` on `inputs` and print the op count and kernel time.

  Tensors in the output are pulled back with .numpy(), whether returned bare or
  inside a list/tuple/dict. This forces the lazy graph to execute before the
  counters are read. Previously only a bare Tensor output was realized.

  Returns the model output with each Tensor replaced by its numpy array.
  """
  GlobalCounters.reset()
  out = _realize(model(*inputs))
  # TODO: return event future to still get the time_sum_s without DEBUG=2
  print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")
  return out
def spec_resnet():
  """Benchmark ResNet50-v1.5 on one random (1, 3, 224, 224) input."""
  from models.resnet import ResNet50
  # arguments evaluate left to right: model is built before the random input, as before
  test_model(ResNet50(), Tensor.randn(1, 3, 224, 224))
def spec_retinanet():
  """Benchmark RetinaNet on a ResNet50 backbone (91 classes, 9 anchors) with one random (1, 3, 224, 224) input."""
  from models.resnet import ResNet50
  from models.retinanet import RetinaNet
  backbone = ResNet50()
  detector = RetinaNet(backbone, num_classes=91, num_anchors=9)
  test_model(detector, Tensor.randn(1, 3, 224, 224))
def spec_unet3d():
  """Benchmark the pretrained 3D UNet on one random (1, 1, 128, 128, 128) input."""
  from models.unet3d import UNet3D
  net = UNet3D()
  net.load_from_pretrained()
  volume = Tensor.randn(1, 1, 128, 128, 128)
  test_model(net, volume)
def spec_rnnt():
  """Benchmark the pretrained RNN-T on two random inputs of shapes (220, 1, 240) and (1, 220)."""
  from models.rnnt import RNNT
  net = RNNT()
  net.load_from_pretrained()
  # NOTE(review): presumably acoustic features and label sequence -- TODO confirm against RNNT.__call__
  inp_x, inp_y = Tensor.randn(220, 1, 240), Tensor.randn(1, 220)
  test_model(net, inp_x, inp_y)
def spec_bert():
  """Benchmark pretrained BertForQuestionAnswering on a single 384-position sequence."""
  from models.bert import BertForQuestionAnswering
  net = BertForQuestionAnswering()
  net.load_from_pretrained()
  seq_shape = (1, 384)
  # NOTE(review): the first two inputs are Gaussian floats, not integer ids / 0-1 mask values;
  # fine for counting ops, but not a realistic input
  ids = Tensor.randn(*seq_shape)
  mask = Tensor.randn(*seq_shape)
  # third input: random values drawn from {0, 1}, stored as float32
  types = Tensor(np.random.randint(0, 2, seq_shape).astype(np.float32))
  test_model(net, ids, mask, types)
def spec_mrcnn():
  """Benchmark pretrained Mask R-CNN (ResNet-50 body) on a list holding one random (3, 224, 224) image."""
  from models.mask_rcnn import MaskRCNN, ResNet
  body = ResNet(50, num_classes=None, stride_in_1x1=True)
  net = MaskRCNN(body)
  net.load_from_pretrained()
  images = [Tensor.randn(3, 224, 224)]
  test_model(net, images)
if __name__ == "__main__":
  # inference only for now
  Tensor.training = False
  Tensor.no_grad = True
  available = sorted(k[len("spec_"):] for k in globals() if k.startswith("spec_"))
  for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","):
    m = m.strip()  # tolerate "resnet, bert"
    if not m: continue  # tolerate empty entries such as a trailing comma
    fxn = globals().get(f"spec_{m}")
    if fxn is None:
      # previously a typo in MODEL was skipped silently; report it instead
      print(f"unknown model {m}, skipping (choices: {','.join(available)})")
      continue
    print(f"testing {m}")
    fxn()