mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-16 09:37:11 -05:00
more test opt
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
os.environ["LAZY"] = "1"
|
||||
os.environ["OPT"] = "2"
|
||||
if "OPT" not in os.environ:
|
||||
os.environ["OPT"] = "2"
|
||||
|
||||
import gc
|
||||
import numpy as np
|
||||
@@ -104,5 +104,23 @@ class TestOpt(unittest.TestCase):
|
||||
print(img_conv)
|
||||
assert len(CL.CACHE) == 2, "optimizer didn't fold conv/elu"
|
||||
|
||||
def test_fold_conv_relu(self):
|
||||
img = Tensor.ones(1,4,8,8)
|
||||
c1 = nn.Conv2d(4, 4, kernel_size=3)
|
||||
c2 = nn.Conv2d(4, 4, kernel_size=3)
|
||||
with CLCache():
|
||||
img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize()
|
||||
print(img_conv)
|
||||
assert len(CL.CACHE) == 2, "optimizer didn't fold conv/relu"
|
||||
|
||||
def test_fold_conv_relu_nobias(self):
|
||||
img = Tensor.ones(1,4,8,8)
|
||||
c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
|
||||
c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
|
||||
with CLCache():
|
||||
img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize()
|
||||
print(img_conv)
|
||||
assert len(CL.CACHE) == 2, "optimizer didn't fold conv/relu"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user