multireduce scheduler tests (#5141)

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
Timmy
2024-07-08 10:28:55 -07:00
committed by GitHub
parent bb2222e488
commit bb7746985f

View File

@@ -687,6 +687,81 @@ class TestSchedule(unittest.TestCase):
out1 = out0[0] + Tensor.empty(1, )
check_schedule([r, out0, out1], 3)
# multireduce spec
def test_std_multireduce_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.std(-1)
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_argmin_multireduce_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmin(-1)
run_schedule(check_schedule(out, 3))
np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
# multireduce spec
def test_argmax_multireduce_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmax(-1)
run_schedule(check_schedule(out, 3))
np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
# multireduce spec
def test_scaled_dot_product_attention_multireduce_fusion(self):
Tensor.manual_seed(0)
q = Tensor.randn(32,8,16,64).realize()
k = Tensor.randn(32,8,16,64).realize()
v = Tensor.randn(32,8,16,64).realize()
out = Tensor.scaled_dot_product_attention(q,k,v)
check_schedule(out, 5) # correctness checked in test_ops
# multireduce spec
def test_ugly_reduceop_pairing(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
b = Tensor.randn(4, 32).realize()
c = Tensor.randn(4, 32).realize()
out = (c * a.sum(-1, keepdim=True)).sum(-1) + (b * a.sum(-1, keepdim=True)).sum(-1) # a.sum has >1 children but should still fuse
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 3))
np.testing.assert_allclose(out.numpy(), \
(c.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1) + (b.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_reduce_expand_reduce_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
out = (a+a.sum(-1, keepdim=True)).sum(-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_reduce_expand_reduce_expand_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
out = a+(a+a.sum(-1,keepdim=True)).sum(-1, keepdim=True)
# run_schedule(check_schedule(out, 2))
run_schedule(check_schedule(out, 3))
np.testing.assert_allclose(out.numpy(), \
a.numpy()+(a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_branching_reduces_and_expands_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
out0 = a+a.sum(-1, keepdim=True)
out1 = out0.sum(-1)
# run_schedule(check_schedule(out, 2))
run_schedule(check_schedule([out0, out1], 3))
np.testing.assert_allclose(out0.numpy(), a.numpy()+a.numpy().sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_multireduce_fusion_simple_sequential(self):
Tensor.manual_seed(0)