mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
test_enet_se
This commit is contained in:
8
test/external/external_test_opt.py
vendored
8
test/external/external_test_opt.py
vendored
@@ -56,6 +56,14 @@ class TestInferenceMinKernels(unittest.TestCase):
|
||||
with CLCache(51):
|
||||
model.forward(img).realize()
|
||||
|
||||
def test_enet_se(self):
|
||||
model = EfficientNet(getenv("ENET_NUM", 0), has_se=True)
|
||||
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
|
||||
img = Tensor.randn(1, 3, 224, 224)
|
||||
# TODO: this seems very high
|
||||
with CLCache(115):
|
||||
model.forward(img).realize()
|
||||
|
||||
def test_resnet(self):
|
||||
model = ResNet18()
|
||||
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
|
||||
|
||||
Reference in New Issue
Block a user