clean external_test_opt.py (#2578)

This commit is contained in:
chenyu
2023-12-02 19:51:08 -05:00
committed by GitHub
parent 171543fc8d
commit 09c9794f3f
2 changed files with 30 additions and 45 deletions

View File

@@ -4,6 +4,8 @@ import os
import torch
if "OPT" not in os.environ:
os.environ["OPT"] = "2"
else:
assert int(os.environ["OPT"]) >= 2, "test is broken with OPT=0 or OPT=1"
import gc
import numpy as np
@@ -18,7 +20,8 @@ from tinygrad.lazy import PUSH_PERMUTES
from tinygrad.jit import CacheCollector
class CLCache:
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None): self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None):
self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
def __enter__(self):
if self.preclear:
gc.collect()
@@ -42,7 +45,10 @@ from tinygrad.nn.state import get_parameters
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestInferenceMinKernels(unittest.TestCase):
def setUp(self):
self.training_old = Tensor.training
Tensor.training = False
def tearDown(self):
Tensor.training = self.training_old
@unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES")
def test_convnext(self):
@@ -155,12 +161,12 @@ class TestOptWChild(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOpt(unittest.TestCase):
def test_muladd(self):
a,b,c = [Tensor.ones(2,2) for _ in range(3)]
with CLCache():
a,b,c = [Tensor.randn(2,2).realize() for _ in range(3)]
na,nb,nc = a.numpy(),b.numpy(),c.numpy()
with CLCache(allowed=1):
d = a * b + c
d.realize()
assert len(CacheCollector.cache) == 1, "optimizer didn't fold muladd"
np.testing.assert_allclose(d.numpy(), np.ones((2,2))*2, rtol=1e-5)
np.testing.assert_allclose(d.numpy(), na*nb+nc, rtol=1e-5)
def test_fold_reduce_elementwise(self):
img = Tensor.ones(32)
@@ -169,7 +175,7 @@ class TestOpt(unittest.TestCase):
ret = img.sum() + addme
ret.realize()
assert len(CacheCollector.cache) == 1, "optimizer didn't fold reduce/elementwise"
assert ret.numpy()[0] == 33
assert ret.item() == 33
def test_fold_batchnorm(self):
with Tensor.train():
@@ -179,7 +185,6 @@ class TestOpt(unittest.TestCase):
img_bn = bn(img).realize()
print(img_bn)
assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
# Tensor.training = False
def test_fold_conv_sgd(self):
with Tensor.train():
@@ -194,7 +199,6 @@ class TestOpt(unittest.TestCase):
# with pushing_permutes it can be 3
# TODO: broken with optim fixes
assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
# Tensor.training = False
def test_fold_2convs_sgd(self):
with Tensor.train():
@@ -206,7 +210,6 @@ class TestOpt(unittest.TestCase):
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
opt.step()
# Tensor.training = False
def test_fold_4convs_sgd(self):
with Tensor.train():
@@ -220,7 +223,6 @@ class TestOpt(unittest.TestCase):
opt.zero_grad()
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
opt.step()
# Tensor.training = False
def test_fold_conv_batchnorm_sgd(self):
with Tensor.train():
@@ -228,12 +230,11 @@ class TestOpt(unittest.TestCase):
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
opt = optim.SGD(get_parameters([c1, bn]))
with CLCache(allowed=18): # this is too high
with CLCache(allowed=17): # this is too high
img_bn = bn(c1(img)).elu().sum()
opt.zero_grad()
img_bn.backward()
opt.step()
# Tensor.training = False
def test_fold_conv_batchnorm_notrain(self):
img = Tensor.ones(1,3,8,8)
@@ -284,7 +285,7 @@ class TestOpt(unittest.TestCase):
def test_permute_was_pushed(self):
a = Tensor.randn(16, 16, 16)
with CLCache():
with CLCache(2):
c = a.sum(2)
d = c.permute(1,0).contiguous()
d.realize()
@@ -294,7 +295,7 @@ class TestOpt(unittest.TestCase):
def test_permute_was_pushed_through_contract_reshape(self):
a = Tensor.randn(4, 4, 4, 4, 4)
with CLCache():
with CLCache(2):
c = a.sum(-1)
d = c.reshape(16,16).permute(1,0).contiguous()
d.realize()
@@ -304,7 +305,7 @@ class TestOpt(unittest.TestCase):
def test_permute_was_pushed_through_contractw1s_reshape(self):
a = Tensor.randn(4, 4, 4, 4, 4)
with CLCache():
with CLCache(2):
c = a.sum(-1)
d = c.reshape(16,1,16).permute(2,1,0).contiguous()
d.realize()
@@ -352,21 +353,9 @@ class TestOpt(unittest.TestCase):
def test_fold_with_contiguous(self):
a = Tensor.randn(16, 16, 16)
b = Tensor.randn(16, 16)
with CLCache():
with CLCache(1):
c = (a.sum(2).contiguous() + b).contiguous()
c.realize()
cache_len = len(CacheCollector.cache)
assert cache_len == 1, "contiguous wasn't folded"
def _test_fold_expand_reduce_helper(self, n, m, axis, allowed):
b = torch.ones(n, m).sum(axis).reshape(n, 1).expand(n, m).sum(axis)
with CLCache(allowed=allowed):
a = Tensor.ones(n, m).sum(axis).reshape(n, 1).expand(n, m).sum(axis)
a.realize()
cache_len = len(CacheCollector.cache)
np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
# TODO: what does these `return cache_len`` do?
return cache_len
def test_expand_reduce_is_folded_on_same_axis(self):
for axis in [0, 1]:
@@ -375,20 +364,16 @@ class TestOpt(unittest.TestCase):
with CLCache(allowed=2):
a = Tensor.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
a.realize()
cache_len = len(CacheCollector.cache)
np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
return cache_len
def test_expand_reduce_is_not_folded_on_different_axes(self):
axis1, axis2 = 0, 1
for n in [4, 8, 16]:
b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
with CLCache(allowed=3):
with CLCache(allowed=2):
a = Tensor.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
a.realize()
cache_len = len(CacheCollector.cache)
np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
return cache_len
if __name__ == '__main__':
unittest.main()

View File

@@ -97,8 +97,6 @@ def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dty
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base)
return ret
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
class LazyBuffer:
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[Buffer]=None, base:Optional[LazyBuffer]=None):
@@ -276,16 +274,6 @@ class LazyBuffer:
# *** movement ops ***
def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and not self.realized and self.optype == BinaryOps and not self.children:
if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES)):
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root
if (root:=get_movementroot(self)) != self and root.st.contiguous and prod(st.shape) == prod(root.shape):
return root.reshape(st.shape)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base)
def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.RESHAPE:
@@ -337,11 +325,23 @@ class LazyBuffer:
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(a1*a2 for a1,a2 in zip(arg, self.op.arg)))
return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)
def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and not self.realized and self.optype == BinaryOps and not self.children:
if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES)):
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root
if (root:=get_movementroot(self)) != self and root.st.contiguous and prod(st.shape) == prod(root.shape):
return root.reshape(st.shape)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base)
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
y = self
for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
return y
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
new_srcs = []
for x in srcs: