mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
schedule2, keep the tests working with small changes (#1932)
* lazy cleanups * ast functions take in LazyOps * op instead of self.op * _base for mops * fix contiguous * start schedule * test_schedule * fix openpilot * more tests * bugfix and test skip * work * make sure things get freed * fix zerosized tensors * fix failing test * fix ceil and friends * fix openpilot * disable training * disable test collectives
This commit is contained in:
333
test/test_schedule.py
Normal file
333
test/test_schedule.py
Normal file
@@ -0,0 +1,333 @@
|
||||
# this will be the new test_ops for the next level
|
||||
# schedule confirms the right things are capable of fusing
|
||||
# NOTE: this has overlap with external_test_opt.py
|
||||
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps
|
||||
from tinygrad.helpers import DEBUG, dtypes
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad import nn
|
||||
|
||||
def check_schedule(t:Tensor, allowed:int):
|
||||
sched = [s for s in t.lazydata.schedule() if s[0].op not in LoadOps]
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
from extra.utils import print_tree
|
||||
for i, s in enumerate(sched):
|
||||
print("op", i)
|
||||
print_tree(s[0])
|
||||
assert len(sched) == allowed
|
||||
# test the ops linearize
|
||||
for s in sched:
|
||||
l = Linearizer(s[0])
|
||||
l.hand_coded_optimizations()
|
||||
l.linearize()
|
||||
|
||||
class TestSchedule(unittest.TestCase):
|
||||
def test_basic_binop_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = Tensor.empty(10)
|
||||
d = a+b+c
|
||||
check_schedule(d, 1)
|
||||
|
||||
def test_basic_binop_fusion_deep(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = Tensor.empty(10)
|
||||
d = Tensor.empty(10)
|
||||
e = a+b+c+d
|
||||
check_schedule(e, 1)
|
||||
|
||||
def test_mulacc_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = (a*b).sum()
|
||||
check_schedule(c, 1)
|
||||
|
||||
def test_mulacc_relu_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = (a*b).sum().relu()
|
||||
check_schedule(c, 1)
|
||||
|
||||
def test_binop_reshape_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = Tensor.empty(5,2)
|
||||
d = (a+b).reshape(5,2)+c
|
||||
check_schedule(d, 1)
|
||||
|
||||
def test_binop_permute_fusion(self):
|
||||
a = Tensor.empty(2,5)
|
||||
b = Tensor.empty(2,5)
|
||||
c = Tensor.empty(5,2)
|
||||
d = (a+b).permute(1,0)+c
|
||||
check_schedule(d, 1)
|
||||
|
||||
def test_binop_elu_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
b = a.elu()
|
||||
check_schedule(b, 1)
|
||||
|
||||
def test_binop_reshape_reduce_fusion(self):
|
||||
a = Tensor.empty(100)
|
||||
b = Tensor.empty(100)
|
||||
c = (a+b).reshape(10, 10).sum(axis=0, keepdim=True)
|
||||
check_schedule(c, 1)
|
||||
|
||||
def test_reduce_reshape_binop_fusion(self):
|
||||
a = Tensor.empty(10,10)
|
||||
b = Tensor.empty(10)
|
||||
c = a.sum(axis=0) + b
|
||||
check_schedule(c, 1)
|
||||
|
||||
@unittest.skip("not pushing permutes through reduces")
|
||||
def test_reduce_permute_binop_fusion(self):
|
||||
a = Tensor.empty(10,10,10)
|
||||
b = Tensor.empty(10,10,1)
|
||||
c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
|
||||
check_schedule(c, 1)
|
||||
|
||||
def test_binop_early_reshape_reduce_fusion(self):
|
||||
a = Tensor.empty(100)
|
||||
b = Tensor.empty(100)
|
||||
c = Tensor.empty(10,10)
|
||||
d = ((a+b).reshape(10,10) + c).sum(axis=0)
|
||||
check_schedule(d, 1)
|
||||
|
||||
def test_diamond_folded(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = Tensor.empty(10)
|
||||
d = Tensor.empty(10)
|
||||
ab = a+b
|
||||
e = (ab+c) + (ab+d)
|
||||
check_schedule(e, 1)
|
||||
|
||||
def test_cache_binaryop(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = a+b
|
||||
d = a+b
|
||||
c.realize()
|
||||
check_schedule(d, 0)
|
||||
|
||||
@unittest.skip("failing in old lazy")
|
||||
def test_cache_binaryop_reshaped(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = a+b
|
||||
d = a.reshape(10,1)+b.reshape(10,1)
|
||||
c.realize()
|
||||
check_schedule(d, 0)
|
||||
|
||||
def test_cache_binaryop_transpose(self):
|
||||
a = Tensor.empty(10,10)
|
||||
b = Tensor.empty(10,10)
|
||||
c = (a.T*b.T).T #.contiguous()
|
||||
d = a*b
|
||||
c.realize()
|
||||
check_schedule(d, 0)
|
||||
|
||||
@unittest.skip("failing in old lazy")
|
||||
def test_cache_binaryop_transpose_realized(self):
|
||||
a = Tensor.randn(10,10).realize()
|
||||
b = Tensor.randn(10,10).realize()
|
||||
c = (a.T*b.T).T
|
||||
d = a*b
|
||||
c.realize()
|
||||
check_schedule(d, 0)
|
||||
|
||||
def test_cache_two_reduceops(self):
|
||||
a = Tensor.empty(10)
|
||||
b = a.sum()
|
||||
c = a.sum()
|
||||
bc = b+c
|
||||
check_schedule(bc, 1)
|
||||
|
||||
def test_fold_double_unary(self):
|
||||
y = Tensor.empty(2)
|
||||
out = y.sum(keepdim=True).sqrt().__neg__()
|
||||
check_schedule(out, 1)
|
||||
|
||||
#@unittest.skip("may want to reconsider this")
|
||||
def test_fold_batchnorm(self):
|
||||
Tensor.training = True
|
||||
img = Tensor.empty(1,32,4,4)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
out = bn(img)
|
||||
check_schedule(out, 3)
|
||||
Tensor.training = False
|
||||
|
||||
def test_fold_conv_relu(self):
|
||||
c1 = nn.Conv2d(3,16,3)
|
||||
c1.weight.realize()
|
||||
c1.bias.realize()
|
||||
|
||||
# run
|
||||
img = Tensor.ones(2,3,64,64)
|
||||
out = c1(img).relu()
|
||||
check_schedule(out, 1)
|
||||
|
||||
def test_fold_conv_elu(self):
|
||||
c1 = nn.Conv2d(3,16,3)
|
||||
c1.weight.realize()
|
||||
c1.bias.realize()
|
||||
|
||||
# run
|
||||
img = Tensor.ones(2,3,64,64)
|
||||
out = c1(img).elu()
|
||||
check_schedule(out, 1)
|
||||
|
||||
def test_two_sum(self):
|
||||
img = Tensor.empty(64,64)
|
||||
x = (img.sum(0) + img.sum(1))
|
||||
out = x.relu()
|
||||
del x # is 3 without this
|
||||
check_schedule(out, 2)
|
||||
|
||||
@unittest.skip("failing in old lazy")
|
||||
def test_push_permute_through_reshape(self):
|
||||
a = Tensor.empty(16,16)
|
||||
b = Tensor.empty(16,16)
|
||||
c = (a+b).reshape(4,4,4,4).permute(2,3,0,1).contiguous()
|
||||
check_schedule(c, 1)
|
||||
|
||||
@unittest.skip("failing in old lazy")
|
||||
def test_push_permute_through_reshape_alt(self):
|
||||
a = Tensor.empty(4,4,4,4)
|
||||
b = Tensor.empty(4,4,4,4)
|
||||
c = (a+b).reshape(16,16).permute(1,0).contiguous()
|
||||
check_schedule(c, 1)
|
||||
|
||||
def test_no_binop_rerun(self):
|
||||
a = Tensor.empty(16)
|
||||
b = Tensor.empty(16)
|
||||
c = a+b
|
||||
d = (a+b).reshape(16,1)
|
||||
c.realize()
|
||||
check_schedule(d, 0)
|
||||
|
||||
def test_multi_permute_should_collapse(self):
|
||||
a = Tensor.empty(4,4,4,4)
|
||||
b = Tensor.empty(16)
|
||||
c = a.sum((0,1)).cast(dtypes.float16).permute(1,0).reshape(4,4,1).permute(1,0,2).reshape(16) + b
|
||||
check_schedule(c, 1)
|
||||
|
||||
@unittest.skip("failing in old lazy")
|
||||
def test_fancy_reshape_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = a+b
|
||||
d = a.reshape(10,1)+b.reshape(10,1)
|
||||
out = c.sum() + d.sum()
|
||||
check_schedule(out, 1)
|
||||
|
||||
"""
|
||||
def test_reshape_doesnt_matter(self):
|
||||
a = Tensor.empty(10)
|
||||
b = a.reshape(10,1)
|
||||
self.assertIs(a.lazydata.backing, b.lazydata.backing)
|
||||
|
||||
def test_permute_doesnt_matter(self):
|
||||
a = Tensor.empty(10, 10)
|
||||
b = a.permute(1,0)
|
||||
c = a.reshape(10, 1, 10).permute(2,1,0)
|
||||
self.assertIs(b.lazydata.backing, c.lazydata.backing)
|
||||
"""
|
||||
|
||||
# NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first
|
||||
@unittest.skip("not real world")
|
||||
def test_children_dont_push(self):
|
||||
a = Tensor.empty(10, 10, 1)
|
||||
b = Tensor.empty(10, 10, 1)
|
||||
d = (a+b).expand(10, 10, 10)
|
||||
e = (a+b).permute(2,1,0)
|
||||
f = d+e
|
||||
check_schedule(f, 2)
|
||||
|
||||
def test_dont_fuse_binops_with_children(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = Tensor.empty(10)
|
||||
keep_me = a+b
|
||||
e = keep_me.sum() # give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
|
||||
d = keep_me+c
|
||||
check_schedule(d, 2)
|
||||
d.realize()
|
||||
check_schedule(keep_me, 0)
|
||||
|
||||
@unittest.skip("failing in old lazy")
|
||||
def test_permute_breaks_fusion(self):
|
||||
a = Tensor.empty(10, 10, 10)
|
||||
b = Tensor.empty(10, 10)
|
||||
c = (a.sum(axis=2) + b).permute(1,0)
|
||||
d = c.permute(1,0)
|
||||
check_schedule(d, 1)
|
||||
|
||||
def test_some_permute_fusion(self):
|
||||
a = Tensor.empty(8192, 16)
|
||||
b = Tensor.empty(1, 16)
|
||||
d = (a.T + b.expand(8192, 16).T)
|
||||
c = a + b.expand(8192, 16)
|
||||
e = d.T
|
||||
check_schedule(c, 1)
|
||||
check_schedule(e, 1)
|
||||
|
||||
# this is the failing case in openpilot...it's very simple like this
|
||||
@unittest.skip("failing in old lazy")
|
||||
def test_image_conv_fusion(self):
|
||||
from tinygrad.nn.image import image_conv2d
|
||||
w1 = Tensor.empty(16, 16, 1, 1)
|
||||
b1 = Tensor.empty(16)
|
||||
w2 = Tensor.empty(16, 16, 1, 1)
|
||||
b2 = Tensor.empty(16)
|
||||
w3 = Tensor.empty(16, 16, 1, 1)
|
||||
b3 = Tensor.empty(16)
|
||||
|
||||
x = Tensor.empty(1, 16, 32, 32)
|
||||
x = base = image_conv2d(x, w1, b1)
|
||||
x = image_conv2d(x, w2, b2) + base
|
||||
x = image_conv2d(x, w3, b3)
|
||||
|
||||
# NOOP, 3 convs, contiguous
|
||||
check_schedule(x, 5)
|
||||
|
||||
@unittest.skip("failing now with contig")
|
||||
def test_image_conv_fusion_minimal(self):
|
||||
b1 = Tensor.empty(16)
|
||||
b2 = Tensor.empty(16)
|
||||
def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)
|
||||
|
||||
x = Tensor.empty(16, 32)
|
||||
x = base = p(x) + b1.reshape(16,1)
|
||||
x = p(x)
|
||||
x = x + b2.reshape(16,1)
|
||||
x = x + base
|
||||
del base
|
||||
x = p(x)
|
||||
check_schedule(x, 4)
|
||||
|
||||
def test_image_conv_fusion_more_minimal(self):
|
||||
b1 = Tensor.empty(16)
|
||||
def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)
|
||||
|
||||
x = Tensor.empty(16, 32)
|
||||
x = base = p(x) + b1.reshape(16,1)
|
||||
x = p(x)
|
||||
del base
|
||||
check_schedule(x, 3)
|
||||
|
||||
def test_resnet_block(self):
|
||||
from models.resnet import BasicBlock
|
||||
Tensor.training = False
|
||||
bb = BasicBlock(64,64)
|
||||
|
||||
x = Tensor.empty(1, 64, 32, 32)
|
||||
out = bb(x)
|
||||
check_schedule(out, 4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user