# load each model here, quick benchmark
from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters, getenv
import numpy as np

def test_model(model, *inputs):
  GlobalCounters.reset()
  out = model(*inputs)
  if isinstance(out, Tensor): out = out.numpy()
  # TODO: return event future to still get the time_sum_s without DEBUG=2
  print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")

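# The TODO above implies GlobalCounters.time_sum_s only gets filled in when kernel
# timing is on (DEBUG>=2). As a sketch (an addition, not part of the original
# harness), a plain wall-clock variant avoids that dependency:
def test_model_walltime(model, *inputs):
  import time
  GlobalCounters.reset()
  st = time.perf_counter()
  out = model(*inputs)
  if isinstance(out, Tensor): out = out.numpy()  # pull the result to the host so the work actually runs
  print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {(time.perf_counter()-st)*1000:.2f} ms wall clock")
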
def spec_resnet():
  # ResNet50-v1.5
  from models.resnet import ResNet50
  mdl = ResNet50()
  img = Tensor.randn(1, 3, 224, 224)
  test_model(mdl, img)

def spec_retinanet():
  # RetinaNet with a ResNet-50 backbone
  from models.resnet import ResNet50
  from models.retinanet import RetinaNet
  mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
  img = Tensor.randn(1, 3, 224, 224)
  test_model(mdl, img)

def spec_unet3d():
  # 3D U-Net
  from models.unet3d import UNet3D
  mdl = UNet3D()
  mdl.load_from_pretrained()
  img = Tensor.randn(1, 1, 128, 128, 128)
  test_model(mdl, img)

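# RNN-T inputs (an assumption read off the shapes, not from the model code): x looks
# like a (time=220, batch=1, features=240) audio-feature sequence and y like a
# (batch=1, time=220) label sequence; the random values only exercise the graph.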
def spec_rnnt():
  from models.rnnt import RNNT
  mdl = RNNT()
  mdl.load_from_pretrained()
  x = Tensor.randn(220, 1, 240)
  y = Tensor.randn(1, 220)
  test_model(mdl, x, y)

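# BERT-QA inputs (an assumption based on the argument names): x stands in for token
# ids, am for the attention mask, tt for token type (segment) ids. The randn/randint
# values only match the (1, 384) shapes; real inputs would be integer ids from a tokenizer.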
def spec_bert():
  from models.bert import BertForQuestionAnswering
  mdl = BertForQuestionAnswering()
  mdl.load_from_pretrained()
  x = Tensor.randn(1, 384)
  am = Tensor.randn(1, 384)
  tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
  test_model(mdl, x, am, tt)

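# Mask R-CNN note: the spec below passes test_model(mdl, [x]), i.e. a list of
# un-batched (3, H, W) images rather than a single batched tensor (an observation
# from the call itself, not a documented contract of the model).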
def spec_mrcnn():
  from models.mask_rcnn import MaskRCNN, ResNet
  mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
  mdl.load_from_pretrained()
  x = Tensor.randn(3, 224, 224)
  test_model(mdl, [x])

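# Example invocation (the script path is an assumption; adjust to wherever this file
# lives in the repo). MODEL selects which spec_* functions run, and DEBUG>=2 is what
# makes GlobalCounters.time_sum_s non-zero per the TODO in test_model:
#   MODEL=resnet,bert DEBUG=2 python3 examples/mlperf/model_spec.py
# Because the main loop dispatches via globals()[f"spec_{m}"](), adding a new
# `def spec_foo():` makes it runnable with MODEL=foo.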
if __name__ == "__main__":
  # inference only for now
  Tensor.training = False
  Tensor.no_grad = True

  for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","):
    nm = f"spec_{m}"
    if nm in globals():
      print(f"testing {m}")
      globals()[nm]()