Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
rewrite the jit in the context of new schedule (#4162)
* rewrite the jit in the context of new schedule
* mypy better
* fix placeholder
* tests
* all functionality should work
* fix tests
* no CacheCollector
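The heart of the change to this test file: CLCache no longer pulls a kernel list out of CacheCollector.finish(); it appends itself to the capturing list in tinygrad.engine.realize and counts kernels through an add(ei) callback. A minimal sketch of that pattern, assuming (as the diff below implies) that the engine calls add(ei) on the registered capturer once per executed item; KernelCounter is an illustrative stand-in for the CLCache helper defined below:

from tinygrad import Tensor
from tinygrad.engine.realize import capturing

class KernelCounter:
  def __init__(self): self.count = 0
  def add(self, ei): self.count += 1  # the engine calls this once per executed item
  def __enter__(self):
    capturing.append(self)            # register as the active capturer
    return self
  def __exit__(self, *exc):
    capturing.clear()                 # deregister

with KernelCounter() as k:
  Tensor.randn(4, 4).sum().realize()
print(f"ran {k.count} kernels")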
test/external/external_test_opt.py (vendored) | 89
@@ -10,24 +10,27 @@ from tinygrad.helpers import getenv
 from tinygrad.nn import optim
 #from tinygrad.lazy import PUSH_PERMUTES
 PUSH_PERMUTES = False
-from tinygrad.engine.jit import CacheCollector
+from tinygrad.engine.realize import capturing
 
 class CLCache:
   def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None):
     self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
+    self.count = 0
+  def add(self, ei): self.count += 1
   def __enter__(self):
     if self.preclear:
       gc.collect()
       for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
         x.realize()
       GlobalCounters.reset()
-    CacheCollector.start(self.var_vals)
+    capturing.append(self)
     print("cache: entering")
+    return self
   def __exit__(self, type, value, traceback):
-    cache = CacheCollector.finish()
-    print(f"cache: exiting with size {len(cache)}", f"allowed {self.allowed}" if self.allowed is not None else "")
+    capturing.clear()
+    print(f"cache: exiting with size {self.count}", f"allowed {self.allowed}" if self.allowed is not None else "")
     if self.allowed is not None:
-      assert len(cache) <= self.allowed and (not self.strict or len(cache) == self.allowed), f"used too many kernels! {len(cache)} > {self.allowed}"
+      assert self.count <= self.allowed and (not self.strict or self.count == self.allowed), f"used too many kernels! {self.count} > {self.allowed}"
 
 from extra.models.convnext import ConvNeXt
 from extra.models.efficientnet import EfficientNet
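Note: the engine side of this contract is not shown in this diff; the hunk above only establishes that CLCache appends itself to capturing and expects add(ei) callbacks while capture is active. A plausible sketch of that dispatch, assuming a single active capturer (capturing.clear() on exit suggests one at a time); run_item is a hypothetical name, not tinygrad API:

capturing = []  # module-level in tinygrad.engine.realize, per the import above

def run_item(ei):
  # assumed wiring: notify the active capturer, then execute the item
  if capturing: capturing[0].add(ei)
  ei.run()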
@@ -77,9 +80,9 @@ class TestInferenceMinKernels(unittest.TestCase):
     model = ViT(embed_dim=192, num_heads=3)
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
     img = Tensor.randn(1, 3, 224, 224)
-    with CLCache(222): # NOTE: this is way too high
+    with CLCache(222) as cache: # NOTE: this is way too high
       out = model.forward(img)
-      assert len(CacheCollector.cache) == 0, "ViT prerealized?"
+      assert cache.count == 0, "ViT prerealized?"
       out.realize()
 
   @unittest.skip("llama is fp16 but CI does not have fp16")
@@ -97,12 +100,12 @@ class TestOptBinOp(unittest.TestCase):
   def _test_no_binop_rerun(self, f1, f2=None, allowed=1):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
-    with CLCache():
+    with CLCache() as cache:
       c = f1(a, b)
       if f2 is not None: d = f2(a, b)
       c.realize()
       if f2 is not None: d.realize()
-      assert len(CacheCollector.cache) == allowed, "binop was rerun!"
+      assert cache.count == allowed, "binop was rerun!"
     if f2 is not None: np.testing.assert_allclose(c.numpy().ravel(), d.numpy().ravel(), rtol=1e-3, atol=1e-5)
 
   def test_no_binop_rerun(self): return self._test_no_binop_rerun(lambda a,b: a*b, lambda a,b: (a*b).reshape(16, 16, 1))
@@ -125,22 +128,22 @@ class TestOptReduceLoop(unittest.TestCase):
   def test_loop_left(self):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
-    with CLCache():
+    with CLCache() as cache:
       t = a.sum(0)
       b = t.reshape(16,1).expand(16,16).sum(0)
       c = (t+b)
       c.realize()
-      assert len(CacheCollector.cache) == 2, "loop left fusion broken"
+      assert cache.count == 2, "loop left fusion broken"
 
   def test_loop_right(self):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
-    with CLCache():
+    with CLCache() as cache:
       t = a.sum(0)
       b = t.reshape(16,1).expand(16,16).sum(0)
       c = (b+t)
       c.realize()
-      assert len(CacheCollector.cache) == 2, "loop right fusion broken"
+      assert cache.count == 2, "loop right fusion broken"
 
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptWChild(unittest.TestCase):
@@ -148,12 +151,12 @@ class TestOptWChild(unittest.TestCase):
   def test_unrealized_child(self):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
-    with CLCache():
+    with CLCache() as cache:
       c = (a*b).sum()
       d = c+1
       e = c+2 # noqa: F841
       d.realize()
-      assert len(CacheCollector.cache) == 2, "don't fuse if you have children"
+      assert cache.count == 2, "don't fuse if you have children"
 
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOpt(unittest.TestCase):
@@ -168,34 +171,34 @@ class TestOpt(unittest.TestCase):
   def test_fold_reduce_elementwise(self):
     img = Tensor.ones(32).contiguous()
     addme = Tensor.ones(1)
-    with CLCache():
+    with CLCache() as cache:
       ret = img.sum() + addme
       ret.realize()
-      assert len(CacheCollector.cache) == 1, "optimizer didn't fold reduce/elementwise"
+      assert cache.count == 1, "optimizer didn't fold reduce/elementwise"
     assert ret.item() == 33
 
   def test_fold_batchnorm(self):
     with Tensor.train():
       img = Tensor.ones(1,32,4,4).contiguous()
       bn = nn.BatchNorm2d(32, track_running_stats=False)
-      with CLCache():
+      with CLCache() as cache:
         img_bn = bn(img).realize()
         print(img_bn)
-        assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
+        assert cache.count == 3, f"optimizer didn't fold batchnorm, got {cache.count}"
 
   def test_fold_conv_sgd(self):
     with Tensor.train():
       img = Tensor.ones(2,3,4,4)
       c1 = nn.Conv2d(3,32,3)
       opt = optim.SGD(get_parameters(c1))
-      with CLCache():
+      with CLCache() as cache:
         opt.zero_grad()
         c1(img).relu().sum().backward()
         opt.step()
         # TODO: this should be 4, but the sum output child stays around
         # with pushing_permutes it can be 3
         # TODO: broken with optim fixes
-        assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
+        assert cache.count in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {cache.count}"
 
   def test_fold_2convs_sgd(self):
     with Tensor.train():
@@ -239,74 +242,74 @@ class TestOpt(unittest.TestCase):
       bn = nn.BatchNorm2d(32, track_running_stats=False)
       # precache the bn
       bn(c1(img)).relu().realize()
-      with CLCache():
+      with CLCache() as cache:
         bn(c1(img)).relu().realize()
-        assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}"
+        assert cache.count == 1, f"optimizer didn't fold conv-batchnorm at test time, got {cache.count}"
 
   def test_fold_conv_batchnorm(self):
     with Tensor.train():
       img = Tensor.ones(1,3,8,8)
       c1 = nn.Conv2d(3,32,3)
       bn = nn.BatchNorm2d(32, track_running_stats=False)
-      with CLCache():
+      with CLCache() as cache:
         img_conv = bn(c1(img)).relu().realize()
         print(img_conv)
-        assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}"
+        assert cache.count == 4, f"optimizer didn't fold conv-batchnorm, got {cache.count}"
 
   def test_fold_conv_elu(self):
     img = Tensor.ones(1,4,8,8)
     c1 = nn.Conv2d(4, 4, kernel_size=3)
     c2 = nn.Conv2d(4, 4, kernel_size=3)
-    with CLCache():
+    with CLCache() as cache:
       img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu]).realize()
       print(img_conv)
-      assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/elu"
+      assert cache.count == 2, "optimizer didn't fold conv/elu"
 
   def test_fold_conv_relu(self):
     img = Tensor.ones(1,4,8,8)
     c1 = nn.Conv2d(4, 4, kernel_size=3)
     c2 = nn.Conv2d(4, 4, kernel_size=3)
-    with CLCache():
+    with CLCache() as cache:
       img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize()
       print(img_conv)
-      assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu"
+      assert cache.count == 2, "optimizer didn't fold conv/relu"
 
   def test_fold_conv_relu_nobias(self):
     img = Tensor.ones(1,4,8,8)
     c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
     c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
-    with CLCache():
+    with CLCache() as cache:
       img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize()
       print(img_conv)
-      assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu"
+      assert cache.count == 2, "optimizer didn't fold conv/relu"
 
   def test_permute_was_pushed(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache(2):
+    with CLCache(2) as cache:
       c = a.sum(2)
       d = c.permute(1,0).contiguous()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
     if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
 
   def test_permute_was_pushed_through_contract_reshape(self):
     a = Tensor.randn(4, 4, 4, 4, 4)
-    with CLCache(2):
+    with CLCache(2) as cache:
       c = a.sum(-1)
       d = c.reshape(16,16).permute(1,0).contiguous()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,16).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
     if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
 
   def test_permute_was_pushed_through_contractw1s_reshape(self):
     a = Tensor.randn(4, 4, 4, 4, 4)
-    with CLCache(2):
+    with CLCache(2) as cache:
       c = a.sum(-1)
       d = c.reshape(16,1,16).permute(2,1,0).contiguous()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,1,16).transpose(2,1,0), d.numpy(), rtol=1e-3, atol=1e-5)
     if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
 
@@ -315,35 +318,35 @@ class TestOpt(unittest.TestCase):
   @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES")
   def test_permute_was_pushed_through_expand_reshape(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache():
+    with CLCache() as cache:
       c = a.sum(2)
       d = c.reshape(4,4,4,4).permute(2,3,0,1).contiguous()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0).reshape(4,4,4,4), d.numpy(), rtol=1e-3, atol=1e-5)
     if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
 
   @unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES")
   def test_no_reduceop_rerun(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache():
+    with CLCache() as cache:
       c = a.sum(2)
       d = a.sum(2).permute(1,0)
       c.realize()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(c.numpy().transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
     assert cache_len == 1, "reduceop was rerun!"
 
   @unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES")
   def test_no_reduceop_rerun_alt(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache():
+    with CLCache() as cache:
       c = a.sum(2).permute(1,0)
       d = a.sum(2)
       c.realize()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(c.numpy(), d.numpy().transpose(1,0), rtol=1e-3, atol=1e-5)
     assert cache_len == 1, "reduceop was rerun!"
 