diff --git a/test/external_test_opt.py b/test/external_test_opt.py index dc690d109f..16a43c898a 100644 --- a/test/external_test_opt.py +++ b/test/external_test_opt.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import os -os.environ["LAZY"] = "1" -os.environ["OPT"] = "2" +if "OPT" not in os.environ: + os.environ["OPT"] = "2" import gc import numpy as np @@ -104,5 +104,23 @@ class TestOpt(unittest.TestCase): print(img_conv) assert len(CL.CACHE) == 2, "optimizer didn't fold conv/elu" + def test_fold_conv_relu(self): + img = Tensor.ones(1,4,8,8) + c1 = nn.Conv2d(4, 4, kernel_size=3) + c2 = nn.Conv2d(4, 4, kernel_size=3) + with CLCache(): + img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize() + print(img_conv) + assert len(CL.CACHE) == 2, "optimizer didn't fold conv/relu" + + def test_fold_conv_relu_nobias(self): + img = Tensor.ones(1,4,8,8) + c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False) + c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False) + with CLCache(): + img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize() + print(img_conv) + assert len(CL.CACHE) == 2, "optimizer didn't fold conv/relu" + if __name__ == '__main__': unittest.main() \ No newline at end of file