diff --git a/test/models/test_efficientnet.py b/test/models/test_efficientnet.py index d7f0ea1821..45f509a379 100644 --- a/test/models/test_efficientnet.py +++ b/test/models/test_efficientnet.py @@ -7,9 +7,10 @@ import numpy as np from PIL import Image from tinygrad.helpers import getenv +from tinygrad.tensor import Tensor from models.efficientnet import EfficientNet from models.vit import ViT -from tinygrad.tensor import Tensor +from models.resnet import ResNet50 def _load_labels(): labels_filename = pathlib.Path(__file__).parent / 'efficientnet/imagenet1000_clsidx_to_labels.txt' @@ -92,5 +93,23 @@ class TestViT(unittest.TestCase): label = _infer(self.model, car_img) self.assertEqual(label, "racer, race car, racing car") +class TestResNet(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = ResNet50() + cls.model.load_from_pretrained() + + @classmethod + def tearDownClass(cls): + del cls.model + + def test_chicken(self): + label = _infer(self.model, chicken_img) + self.assertEqual(label, "hen") + + def test_car(self): + label = _infer(self.model, car_img) + self.assertEqual(label, "sports car, sport car") + if __name__ == '__main__': unittest.main()