mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
clean external_test_opt.py (#2578)
51  test/external/external_test_opt.py (vendored)
@@ -4,6 +4,8 @@ import os
 import torch
 if "OPT" not in os.environ:
   os.environ["OPT"] = "2"
+else:
+  assert int(os.environ["OPT"]) >= 2, "test is broken with OPT=0 or OPT=1"

 import gc
 import numpy as np
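A usage note, assuming tinygrad's usual environment-variable conventions: with the guard above, OPT defaults to 2 when unset, and the test classes below skip unless the GPU (OpenCL) device is the default, so a local run would presumably look like:

  GPU=1 python test/external/external_test_opt.py

Forcing the device with GPU=1 is an assumption here, not something stated in the diff; the assert above only requires that OPT, if set, is at least 2.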
@@ -18,7 +20,8 @@ from tinygrad.lazy import PUSH_PERMUTES
 from tinygrad.jit import CacheCollector

 class CLCache:
-  def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None): self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
+  def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None):
+    self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
   def __enter__(self):
     if self.preclear:
       gc.collect()
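As a reading aid: CLCache, defined above, records the kernels enqueued inside its with-block through CacheCollector, and the allowed= argument threaded through this commit is the expected kernel count. A minimal sketch of the pattern the tests below rely on, assuming this module's CLCache and a GPU device:

  from tinygrad.tensor import Tensor
  a, b, c = [Tensor.randn(2, 2).realize() for _ in range(3)]
  with CLCache(allowed=1):       # exiting the block checks the cached-kernel count
    (a * b + c).realize()        # mul+add is expected to fuse into a single kernel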
@@ -42,7 +45,10 @@ from tinygrad.nn.state import get_parameters
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestInferenceMinKernels(unittest.TestCase):
   def setUp(self):
+    self.training_old = Tensor.training
     Tensor.training = False
+  def tearDown(self):
+    Tensor.training = self.training_old

   @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES")
   def test_convnext(self):
@@ -155,12 +161,12 @@ class TestOptWChild(unittest.TestCase):
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOpt(unittest.TestCase):
   def test_muladd(self):
-    a,b,c = [Tensor.ones(2,2) for _ in range(3)]
-    with CLCache():
+    a,b,c = [Tensor.randn(2,2).realize() for _ in range(3)]
+    na,nb,nc = a.numpy(),b.numpy(),c.numpy()
+    with CLCache(allowed=1):
       d = a * b + c
       d.realize()
-      assert len(CacheCollector.cache) == 1, "optimizer didn't fold muladd"
-    np.testing.assert_allclose(d.numpy(), np.ones((2,2))*2, rtol=1e-5)
+    np.testing.assert_allclose(d.numpy(), na*nb+nc, rtol=1e-5)

   def test_fold_reduce_elementwise(self):
     img = Tensor.ones(32)
@@ -169,7 +175,7 @@ class TestOpt(unittest.TestCase):
       ret = img.sum() + addme
       ret.realize()
       assert len(CacheCollector.cache) == 1, "optimizer didn't fold reduce/elementwise"
-    assert ret.numpy()[0] == 33
+    assert ret.item() == 33

   def test_fold_batchnorm(self):
     with Tensor.train():
@@ -179,7 +185,6 @@ class TestOpt(unittest.TestCase):
       with CLCache():
         img_bn = bn(img).realize()
         print(img_bn)
         assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
-      # Tensor.training = False

   def test_fold_conv_sgd(self):
     with Tensor.train():
@@ -194,7 +199,6 @@ class TestOpt(unittest.TestCase):
         # with pushing_permutes it can be 3
         # TODO: broken with optim fixes
         assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
-      # Tensor.training = False

   def test_fold_2convs_sgd(self):
     with Tensor.train():
@@ -206,7 +210,6 @@ class TestOpt(unittest.TestCase):
         opt.zero_grad()
         c2(c1(img).relu()).relu().sum().backward()
         opt.step()
-      # Tensor.training = False

   def test_fold_4convs_sgd(self):
     with Tensor.train():
@@ -220,7 +223,6 @@ class TestOpt(unittest.TestCase):
         opt.zero_grad()
         c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
         opt.step()
-      # Tensor.training = False

   def test_fold_conv_batchnorm_sgd(self):
     with Tensor.train():
@@ -228,12 +230,11 @@ class TestOpt(unittest.TestCase):
       c1 = nn.Conv2d(3,32,3)
       bn = nn.BatchNorm2d(32, track_running_stats=False)
       opt = optim.SGD(get_parameters([c1, bn]))
-      with CLCache(allowed=18): # this is too high
+      with CLCache(allowed=17): # this is too high
         img_bn = bn(c1(img)).elu().sum()
         opt.zero_grad()
         img_bn.backward()
         opt.step()
-      # Tensor.training = False

   def test_fold_conv_batchnorm_notrain(self):
     img = Tensor.ones(1,3,8,8)
@@ -284,7 +285,7 @@ class TestOpt(unittest.TestCase):

   def test_permute_was_pushed(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache():
+    with CLCache(2):
       c = a.sum(2)
       d = c.permute(1,0).contiguous()
       d.realize()
@@ -294,7 +295,7 @@ class TestOpt(unittest.TestCase):

   def test_permute_was_pushed_through_contract_reshape(self):
     a = Tensor.randn(4, 4, 4, 4, 4)
-    with CLCache():
+    with CLCache(2):
       c = a.sum(-1)
       d = c.reshape(16,16).permute(1,0).contiguous()
       d.realize()
@@ -304,7 +305,7 @@ class TestOpt(unittest.TestCase):

   def test_permute_was_pushed_through_contractw1s_reshape(self):
     a = Tensor.randn(4, 4, 4, 4, 4)
-    with CLCache():
+    with CLCache(2):
       c = a.sum(-1)
       d = c.reshape(16,1,16).permute(2,1,0).contiguous()
       d.realize()
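For context on these three tests: with PUSH_PERMUTES enabled, the permute is expected to ride along with the reduce, which is presumably why CLCache(2) now bounds each chain at two kernels rather than paying for the permute separately. A rough sketch mirroring test_permute_was_pushed:

  from tinygrad.tensor import Tensor
  a = Tensor.randn(16, 16, 16)
  d = a.sum(2).permute(1, 0).contiguous()  # reduce, then permute, then force a copy
  d.realize()                              # under PUSH_PERMUTES this should enqueue two kernels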
@@ -352,21 +353,9 @@ class TestOpt(unittest.TestCase):
   def test_fold_with_contiguous(self):
     a = Tensor.randn(16, 16, 16)
     b = Tensor.randn(16, 16)
-    with CLCache():
+    with CLCache(1):
       c = (a.sum(2).contiguous() + b).contiguous()
       c.realize()
-    cache_len = len(CacheCollector.cache)
-    assert cache_len == 1, "contiguous wasn't folded"

-  def _test_fold_expand_reduce_helper(self, n, m, axis, allowed):
-    b = torch.ones(n, m).sum(axis).reshape(n, 1).expand(n, m).sum(axis)
-    with CLCache(allowed=allowed):
-      a = Tensor.ones(n, m).sum(axis).reshape(n, 1).expand(n, m).sum(axis)
-      a.realize()
-    cache_len = len(CacheCollector.cache)
-    np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
-    # TODO: what does these `return cache_len`` do?
-    return cache_len
-
   def test_expand_reduce_is_folded_on_same_axis(self):
     for axis in [0, 1]:
@@ -375,20 +364,16 @@ class TestOpt(unittest.TestCase):
       with CLCache(allowed=2):
         a = Tensor.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
         a.realize()
-      cache_len = len(CacheCollector.cache)
       np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
-      return cache_len

   def test_expand_reduce_is_not_folded_on_different_axes(self):
     axis1, axis2 = 0, 1
     for n in [4, 8, 16]:
       b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
-      with CLCache(allowed=3):
+      with CLCache(allowed=2):
         a = Tensor.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
         a.realize()
-      cache_len = len(CacheCollector.cache)
       np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
-      return cache_len

 if __name__ == '__main__':
   unittest.main()
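A quick worked example of the values the expand/reduce folding tests compare: for an all-ones (n, n) input, the first sum produces a vector of n's, the reshape/expand broadcasts it back to (n, n), and the second sum gives n*n in every slot, so the torch reference and the tinygrad chain should agree on a length-n vector of n*n:

  import numpy as np
  n = 4
  first = np.ones((n, n)).sum(1).reshape(n, 1)   # every entry is n
  expanded = first * np.ones((n, n))             # broadcast back to (n, n), like expand
  print(expanded.sum(1))                         # [16. 16. 16. 16.] -> n*n per entry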
@@ -97,8 +97,6 @@ def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dty
   lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base)
   return ret

-UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
-
 class LazyBuffer:
   __deletable__ = ('op',)
   def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[Buffer]=None, base:Optional[LazyBuffer]=None):
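An aside on the set being moved here: these ops are presumably marked unsafe for padding because zero padding only commutes with functions that keep 0 at 0, which none of them do:

  print(2.0 ** 0.0)    # EXP2 of a padded zero is 1.0, not 0.0
  print(0.0 < 1.0)     # CMPLT against a padded zero flips to True
  # LOG2(0) heads to -inf, and DIV/RECIP by a padded zero are undefined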
@@ -276,16 +274,6 @@ class LazyBuffer:

   # *** movement ops ***

-  def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
-    if SHUFFLE_MOVEMENT_OPS and not self.realized and self.optype == BinaryOps and not self.children:
-      if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES)):
-        return self.op.replace_with_movement_ops([(op, arg)])
-    if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
-      # MovementOps aren't stacked any more, they each have one parent, find the root
-      if (root:=get_movementroot(self)) != self and root.st.contiguous and prod(st.shape) == prod(root.shape):
-        return root.reshape(st.shape)
-    return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base)
-
   def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     if not self.realized and self.op.op == MovementOps.RESHAPE:
@@ -337,11 +325,23 @@ class LazyBuffer:
     if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(a1*a2 for a1,a2 in zip(arg, self.op.arg)))
     return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)

+  def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
+    if SHUFFLE_MOVEMENT_OPS and not self.realized and self.optype == BinaryOps and not self.children:
+      if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES)):
+        return self.op.replace_with_movement_ops([(op, arg)])
+    if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
+      # MovementOps aren't stacked any more, they each have one parent, find the root
+      if (root:=get_movementroot(self)) != self and root.st.contiguous and prod(st.shape) == prod(root.shape):
+        return root.reshape(st.shape)
+    return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base)
+
   def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
     y = self
     for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
     return y

+UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
+
 def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
   new_srcs = []
   for x in srcs:
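For readers unfamiliar with MOVEMENT_OPS_DISPATCHER as used in replace_with_movement_ops above, here is a self-contained, hypothetical sketch of the same dispatch pattern (the names Buf, DISPATCH, and replay are illustrative, not tinygrad's): each movement op maps to the method that applies it, so a recorded list of (op, arg) pairs can be replayed in order.

  from enum import Enum, auto

  class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto()

  class Buf:
    def __init__(self, shape): self.shape = shape
    def reshape(self, arg): return Buf(arg)
    def permute(self, arg): return Buf(tuple(self.shape[i] for i in arg))

  # each movement op dispatches to the method that applies it
  DISPATCH = {MovementOps.RESHAPE: Buf.reshape, MovementOps.PERMUTE: Buf.permute}

  def replay(buf, ops):
    for op, arg in ops: buf = DISPATCH[op](buf, arg)
    return buf

  print(replay(Buf((4, 4, 4)), [(MovementOps.RESHAPE, (16, 4)), (MovementOps.PERMUTE, (1, 0))]).shape)  # (4, 16)

The real method walks the list the same way, but each dispatched call goes back through the lazy-graph builders shown in the hunk above.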