Inference test: add tests for ResNet50 (#773)

* Add ResNet inference test and cannon

* Test with ResNet50

* test_car works with resnet fix
This commit is contained in:
Jacky Lee
2023-05-13 21:18:15 -07:00
committed by GitHub
parent e5b4b36cba
commit c552f6f92b

View File

@@ -7,9 +7,10 @@ import numpy as np
from PIL import Image
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor
from models.efficientnet import EfficientNet
from models.vit import ViT
from tinygrad.tensor import Tensor
from models.resnet import ResNet50
def _load_labels():
labels_filename = pathlib.Path(__file__).parent / 'efficientnet/imagenet1000_clsidx_to_labels.txt'
@@ -92,5 +93,23 @@ class TestViT(unittest.TestCase):
label = _infer(self.model, car_img)
self.assertEqual(label, "racer, race car, racing car")
class TestResNet(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = ResNet50()
cls.model.load_from_pretrained()
@classmethod
def tearDownClass(cls):
del cls.model
def test_chicken(self):
label = _infer(self.model, chicken_img)
self.assertEqual(label, "hen")
def test_car(self):
label = _infer(self.model, car_img)
self.assertEqual(label, "sports car, sport car")
if __name__ == '__main__':
unittest.main()