mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
test reduce graph permutations (#4291)
This commit is contained in:
@@ -465,6 +465,40 @@ class TestSchedule(unittest.TestCase):
|
||||
out1 = a.sum() + out0 + 4
|
||||
check_schedule([out0, out1], 1)
|
||||
|
||||
def test_group_outside_reduce(self):
|
||||
a = Tensor.empty((4, 4))
|
||||
b = Tensor.empty((4, 4))
|
||||
out0 = a.sum() + 2
|
||||
# b.sum() is not a descendant of the fused nodes
|
||||
out1 = a.sum() + b.sum() + 4
|
||||
check_schedule([out0, out1], 3) # TODO: this can fuse
|
||||
|
||||
def test_reduce_multiple_paths_fuse(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
out0 = a.sum().exp2()
|
||||
# out1 has two paths to a.sum()
|
||||
out1 = a.sum() + out0
|
||||
check_schedule([out0, out1], 1)
|
||||
|
||||
def test_reduce_multiple_paths_midreduce(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
r = a.sum()
|
||||
out0 = r.exp2()
|
||||
# reduce node in the indirect path from r to out2
|
||||
out1 = (a - out0).max()
|
||||
out2 = r + out1
|
||||
check_schedule([r, out0, out1, out2], 4)
|
||||
|
||||
def test_reduce_multiple_paths_midexpand(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = Tensor.empty(4, 4, 4)
|
||||
r = a.sum()
|
||||
out0 = r.exp2()
|
||||
# e1 is in the indirect path from a.sum() to out1
|
||||
e = b + out0
|
||||
out1 = r + e[0][0][0]
|
||||
check_schedule([r, out0, out1, e], 4)
|
||||
|
||||
def test_group_midreduce_nofuse(self):
|
||||
a = Tensor.empty((4, 4))
|
||||
b = Tensor.empty((4, 4))
|
||||
|
||||
Reference in New Issue
Block a user