test_enet_se

This commit is contained in:
George Hotz
2023-03-24 10:04:30 -07:00
parent fafe8e9ce2
commit 1cb5b2d015

View File

@@ -56,6 +56,14 @@ class TestInferenceMinKernels(unittest.TestCase):
with CLCache(51):
model.forward(img).realize()
def test_enet_se(self):
model = EfficientNet(getenv("ENET_NUM", 0), has_se=True)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
img = Tensor.randn(1, 3, 224, 224)
# TODO: this seems very high
with CLCache(115):
model.forward(img).realize()
def test_resnet(self):
model = ResNet18()
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))