mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
clean external_test_opt.py (#2578)
This commit is contained in:
51
test/external/external_test_opt.py
vendored
51
test/external/external_test_opt.py
vendored
@@ -4,6 +4,8 @@ import os
|
||||
import torch
|
||||
if "OPT" not in os.environ:
|
||||
os.environ["OPT"] = "2"
|
||||
else:
|
||||
assert int(os.environ["OPT"]) >= 2, "test is broken with OPT=0 or OPT=1"
|
||||
|
||||
import gc
|
||||
import numpy as np
|
||||
@@ -18,7 +20,8 @@ from tinygrad.lazy import PUSH_PERMUTES
|
||||
from tinygrad.jit import CacheCollector
|
||||
|
||||
class CLCache:
|
||||
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None): self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
|
||||
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None):
|
||||
self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
|
||||
def __enter__(self):
|
||||
if self.preclear:
|
||||
gc.collect()
|
||||
@@ -42,7 +45,10 @@ from tinygrad.nn.state import get_parameters
|
||||
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
|
||||
class TestInferenceMinKernels(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.training_old = Tensor.training
|
||||
Tensor.training = False
|
||||
def tearDown(self):
|
||||
Tensor.training = self.training_old
|
||||
|
||||
@unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES")
|
||||
def test_convnext(self):
|
||||
@@ -155,12 +161,12 @@ class TestOptWChild(unittest.TestCase):
|
||||
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
|
||||
class TestOpt(unittest.TestCase):
|
||||
def test_muladd(self):
|
||||
a,b,c = [Tensor.ones(2,2) for _ in range(3)]
|
||||
with CLCache():
|
||||
a,b,c = [Tensor.randn(2,2).realize() for _ in range(3)]
|
||||
na,nb,nc = a.numpy(),b.numpy(),c.numpy()
|
||||
with CLCache(allowed=1):
|
||||
d = a * b + c
|
||||
d.realize()
|
||||
assert len(CacheCollector.cache) == 1, "optimizer didn't fold muladd"
|
||||
np.testing.assert_allclose(d.numpy(), np.ones((2,2))*2, rtol=1e-5)
|
||||
np.testing.assert_allclose(d.numpy(), na*nb+nc, rtol=1e-5)
|
||||
|
||||
def test_fold_reduce_elementwise(self):
|
||||
img = Tensor.ones(32)
|
||||
@@ -169,7 +175,7 @@ class TestOpt(unittest.TestCase):
|
||||
ret = img.sum() + addme
|
||||
ret.realize()
|
||||
assert len(CacheCollector.cache) == 1, "optimizer didn't fold reduce/elementwise"
|
||||
assert ret.numpy()[0] == 33
|
||||
assert ret.item() == 33
|
||||
|
||||
def test_fold_batchnorm(self):
|
||||
with Tensor.train():
|
||||
@@ -179,7 +185,6 @@ class TestOpt(unittest.TestCase):
|
||||
img_bn = bn(img).realize()
|
||||
print(img_bn)
|
||||
assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
|
||||
# Tensor.training = False
|
||||
|
||||
def test_fold_conv_sgd(self):
|
||||
with Tensor.train():
|
||||
@@ -194,7 +199,6 @@ class TestOpt(unittest.TestCase):
|
||||
# with pushing_permutes it can be 3
|
||||
# TODO: broken with optim fixes
|
||||
assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
|
||||
# Tensor.training = False
|
||||
|
||||
def test_fold_2convs_sgd(self):
|
||||
with Tensor.train():
|
||||
@@ -206,7 +210,6 @@ class TestOpt(unittest.TestCase):
|
||||
opt.zero_grad()
|
||||
c2(c1(img).relu()).relu().sum().backward()
|
||||
opt.step()
|
||||
# Tensor.training = False
|
||||
|
||||
def test_fold_4convs_sgd(self):
|
||||
with Tensor.train():
|
||||
@@ -220,7 +223,6 @@ class TestOpt(unittest.TestCase):
|
||||
opt.zero_grad()
|
||||
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
|
||||
opt.step()
|
||||
# Tensor.training = False
|
||||
|
||||
def test_fold_conv_batchnorm_sgd(self):
|
||||
with Tensor.train():
|
||||
@@ -228,12 +230,11 @@ class TestOpt(unittest.TestCase):
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
opt = optim.SGD(get_parameters([c1, bn]))
|
||||
with CLCache(allowed=18): # this is too high
|
||||
with CLCache(allowed=17): # this is too high
|
||||
img_bn = bn(c1(img)).elu().sum()
|
||||
opt.zero_grad()
|
||||
img_bn.backward()
|
||||
opt.step()
|
||||
# Tensor.training = False
|
||||
|
||||
def test_fold_conv_batchnorm_notrain(self):
|
||||
img = Tensor.ones(1,3,8,8)
|
||||
@@ -284,7 +285,7 @@ class TestOpt(unittest.TestCase):
|
||||
|
||||
def test_permute_was_pushed(self):
|
||||
a = Tensor.randn(16, 16, 16)
|
||||
with CLCache():
|
||||
with CLCache(2):
|
||||
c = a.sum(2)
|
||||
d = c.permute(1,0).contiguous()
|
||||
d.realize()
|
||||
@@ -294,7 +295,7 @@ class TestOpt(unittest.TestCase):
|
||||
|
||||
def test_permute_was_pushed_through_contract_reshape(self):
|
||||
a = Tensor.randn(4, 4, 4, 4, 4)
|
||||
with CLCache():
|
||||
with CLCache(2):
|
||||
c = a.sum(-1)
|
||||
d = c.reshape(16,16).permute(1,0).contiguous()
|
||||
d.realize()
|
||||
@@ -304,7 +305,7 @@ class TestOpt(unittest.TestCase):
|
||||
|
||||
def test_permute_was_pushed_through_contractw1s_reshape(self):
|
||||
a = Tensor.randn(4, 4, 4, 4, 4)
|
||||
with CLCache():
|
||||
with CLCache(2):
|
||||
c = a.sum(-1)
|
||||
d = c.reshape(16,1,16).permute(2,1,0).contiguous()
|
||||
d.realize()
|
||||
@@ -352,21 +353,9 @@ class TestOpt(unittest.TestCase):
|
||||
def test_fold_with_contiguous(self):
|
||||
a = Tensor.randn(16, 16, 16)
|
||||
b = Tensor.randn(16, 16)
|
||||
with CLCache():
|
||||
with CLCache(1):
|
||||
c = (a.sum(2).contiguous() + b).contiguous()
|
||||
c.realize()
|
||||
cache_len = len(CacheCollector.cache)
|
||||
assert cache_len == 1, "contiguous wasn't folded"
|
||||
|
||||
def _test_fold_expand_reduce_helper(self, n, m, axis, allowed):
|
||||
b = torch.ones(n, m).sum(axis).reshape(n, 1).expand(n, m).sum(axis)
|
||||
with CLCache(allowed=allowed):
|
||||
a = Tensor.ones(n, m).sum(axis).reshape(n, 1).expand(n, m).sum(axis)
|
||||
a.realize()
|
||||
cache_len = len(CacheCollector.cache)
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
|
||||
# TODO: what does these `return cache_len`` do?
|
||||
return cache_len
|
||||
|
||||
def test_expand_reduce_is_folded_on_same_axis(self):
|
||||
for axis in [0, 1]:
|
||||
@@ -375,20 +364,16 @@ class TestOpt(unittest.TestCase):
|
||||
with CLCache(allowed=2):
|
||||
a = Tensor.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
|
||||
a.realize()
|
||||
cache_len = len(CacheCollector.cache)
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
|
||||
return cache_len
|
||||
|
||||
def test_expand_reduce_is_not_folded_on_different_axes(self):
|
||||
axis1, axis2 = 0, 1
|
||||
for n in [4, 8, 16]:
|
||||
b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
|
||||
with CLCache(allowed=3):
|
||||
with CLCache(allowed=2):
|
||||
a = Tensor.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
|
||||
a.realize()
|
||||
cache_len = len(CacheCollector.cache)
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
|
||||
return cache_len
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user