mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Inference test: add tests for ResNet50 (#773)
* Add ResNet inference test and cannon * Test with ResNet50 * test_car works with resnet fix
This commit is contained in:
@@ -7,9 +7,10 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.tensor import Tensor
|
||||
from models.efficientnet import EfficientNet
|
||||
from models.vit import ViT
|
||||
from tinygrad.tensor import Tensor
|
||||
from models.resnet import ResNet50
|
||||
|
||||
def _load_labels():
|
||||
labels_filename = pathlib.Path(__file__).parent / 'efficientnet/imagenet1000_clsidx_to_labels.txt'
|
||||
@@ -92,5 +93,23 @@ class TestViT(unittest.TestCase):
|
||||
label = _infer(self.model, car_img)
|
||||
self.assertEqual(label, "racer, race car, racing car")
|
||||
|
||||
class TestResNet(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = ResNet50()
|
||||
cls.model.load_from_pretrained()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
del cls.model
|
||||
|
||||
def test_chicken(self):
|
||||
label = _infer(self.model, chicken_img)
|
||||
self.assertEqual(label, "hen")
|
||||
|
||||
def test_car(self):
|
||||
label = _infer(self.model, car_img)
|
||||
self.assertEqual(label, "sports car, sport car")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user