mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
@@ -453,33 +453,28 @@ class TestSchedule(unittest.TestCase):
|
||||
out = x.contiguous() + y.contiguous()
|
||||
check_schedule(out, 2)
|
||||
|
||||
def test_group_fuse(self):
|
||||
a = Tensor.empty((4, 4))
|
||||
def test_reduce_same_size(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
out0 = a.sum() + 2
|
||||
out1 = a.sum() + 4
|
||||
check_schedule([out0, out1], 1)
|
||||
out2 = out0 * out1
|
||||
check_schedule([out0, out1, out2], 2)
|
||||
|
||||
def test_group_inner_deps_fuse(self):
|
||||
a = Tensor.empty((4, 4))
|
||||
out0 = a.sum() + 2
|
||||
out1 = a.sum() + out0 + 4
|
||||
check_schedule([out0, out1], 1)
|
||||
|
||||
def test_group_outside_reduce(self):
|
||||
a = Tensor.empty((4, 4))
|
||||
b = Tensor.empty((4, 4))
|
||||
out0 = a.sum() + 2
|
||||
# b.sum() is not a descendant of the fused nodes
|
||||
out1 = a.sum() + b.sum() + 4
|
||||
check_schedule([out0, out1], 3) # TODO: this can fuse
|
||||
|
||||
def test_reduce_multiple_paths_fuse(self):
|
||||
def test_reduce_multiple_paths(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
out0 = a.sum().exp2()
|
||||
# out1 has two paths to a.sum()
|
||||
out1 = a.sum() + out0
|
||||
check_schedule([out0, out1], 1)
|
||||
|
||||
def test_reduce_ext_reduce_child(self):
|
||||
a = Tensor.empty((4, 4))
|
||||
b = Tensor.empty((4, 4))
|
||||
# b.sum() is not a descendant of the fused nodes
|
||||
out0 = a.sum() + b.sum() + 2
|
||||
out1 = a.sum() + b.sum() + 4
|
||||
check_schedule([out0, out1], 4)
|
||||
|
||||
def test_reduce_multiple_paths_midreduce(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
r = a.sum()
|
||||
@@ -489,6 +484,14 @@ class TestSchedule(unittest.TestCase):
|
||||
out2 = r + out1
|
||||
check_schedule([r, out0, out1, out2], 4)
|
||||
|
||||
def test_reduce_multiple_paths_midreduce_fused(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = Tensor.empty(4, 4)
|
||||
out0 = a.sum() + 4
|
||||
out1 = b.max() + out0*2
|
||||
out2 = a.sum() + out1
|
||||
check_schedule([out0, out1, out2], 4)
|
||||
|
||||
def test_reduce_multiple_paths_midexpand(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = Tensor.empty(4, 4, 4)
|
||||
@@ -499,26 +502,33 @@ class TestSchedule(unittest.TestCase):
|
||||
out1 = r + e[0][0][0]
|
||||
check_schedule([r, out0, out1, e], 4)
|
||||
|
||||
def test_group_midreduce_nofuse(self):
|
||||
a = Tensor.empty((4, 4))
|
||||
b = Tensor.empty((4, 4))
|
||||
out0 = a.sum() + 2
|
||||
out1 = a.sum() + b.sum() + 4
|
||||
check_schedule([out0, out1], 3)
|
||||
|
||||
def test_group_midexpand_nofuse(self):
|
||||
def test_reduce_expand_child(self):
|
||||
a = Tensor.empty((32, 32, 32))
|
||||
b = Tensor.empty((1, 16))
|
||||
out0 = a.sum() + 2
|
||||
out1 = a.sum() + b
|
||||
check_schedule([out0, out1], 4)
|
||||
|
||||
def test_group_midshrink_fuse(self):
|
||||
def test_reduce_shrink_child(self):
|
||||
a = Tensor.empty(100, 100)
|
||||
b = Tensor.empty(10,)
|
||||
out0 = a.sum() + b[0]
|
||||
out1 = a.sum() + 2
|
||||
check_schedule([out0, out1], 1)
|
||||
c = a.sum() + b[0]
|
||||
d = a.sum() + 2
|
||||
check_schedule([c, d], 1)
|
||||
|
||||
def test_reduce_multiple_paths_midshrink(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
r = a.sum(axis=1)
|
||||
out0 = r.exp2()
|
||||
out1 = out0[0] + out0
|
||||
check_schedule([r, out0, out1], 3)
|
||||
|
||||
def test_reduce_shrink_output(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
r = a.sum(keepdim=True)
|
||||
out0 = r.exp2()
|
||||
out1 = out0[0] + Tensor.empty(1, )
|
||||
check_schedule([r, out0, out1], 3)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_prefer_half_buffer(self):
|
||||
|
||||
Reference in New Issue
Block a user