diff --git a/test/test_schedule.py b/test/test_schedule.py
index 37e1fff4c8..98ff6c7943 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -687,6 +687,81 @@ class TestSchedule(unittest.TestCase):
     out1 = out0[0] + Tensor.empty(1, )
     check_schedule([r, out0, out1], 3)
 
+  # multireduce spec: std over the last axis, compared against numpy with ddof=1
+  def test_std_multireduce_fusion(self):
+    Tensor.manual_seed(0)
+    inp = Tensor.randn(4, 32).realize()
+    result = inp.std(-1)
+    run_schedule(check_schedule(result, 2))
+    np.testing.assert_allclose(result.numpy(), np.std(inp.numpy(), axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
+
+  # multireduce spec: argmin over the last axis must match numpy exactly
+  def test_argmin_multireduce_fusion(self):
+    Tensor.manual_seed(0)
+    inp = Tensor.randn(4, 32).realize()
+    result = inp.argmin(-1)
+    run_schedule(check_schedule(result, 3))
+    np.testing.assert_equal(result.numpy(), np.argmin(inp.numpy(), axis=-1))
+
+  # multireduce spec: argmax over the last axis must match numpy exactly
+  def test_argmax_multireduce_fusion(self):
+    Tensor.manual_seed(0)
+    inp = Tensor.randn(4, 32).realize()
+    result = inp.argmax(-1)
+    run_schedule(check_schedule(result, 3))
+    np.testing.assert_equal(result.numpy(), np.argmax(inp.numpy(), axis=-1))
+
+  # multireduce spec: only the schedule is checked here, nothing is executed
+  def test_scaled_dot_product_attention_multireduce_fusion(self):
+    Tensor.manual_seed(0)
+    query = Tensor.randn(32,8,16,64).realize()
+    key = Tensor.randn(32,8,16,64).realize()
+    value = Tensor.randn(32,8,16,64).realize()
+    attn = Tensor.scaled_dot_product_attention(query,key,value)
+    check_schedule(attn, 5) # correctness checked in test_ops
+
+  # multireduce spec: two outer reduces both consume the same inner row sum
+  def test_ugly_reduceop_pairing(self):
+    Tensor.manual_seed(0)
+    base = Tensor.randn(4, 32).realize()
+    rhs = Tensor.randn(4, 32).realize()
+    lhs = Tensor.randn(4, 32).realize()
+    result = (lhs * base.sum(-1, keepdim=True)).sum(-1) + (rhs * base.sum(-1, keepdim=True)).sum(-1) # base.sum has >1 children but should still fuse
+    # disabled multireduce variant: run_schedule(check_schedule(result, 1))
+    run_schedule(check_schedule(result, 3))
+    base_sum = np.sum(base.numpy(), axis=-1, keepdims=True)
+    np.testing.assert_allclose(result.numpy(), (lhs.numpy()*base_sum).sum(-1) + (rhs.numpy()*base_sum).sum(-1), atol=1e-4, rtol=1e-4)
+
+  # multireduce spec: reduce -> expand -> reduce
+  def test_reduce_expand_reduce_fusion(self):
+    Tensor.manual_seed(0)
+    inp = Tensor.randn(4, 32).realize()
+    result = (inp+inp.sum(-1, keepdim=True)).sum(-1)
+    run_schedule(check_schedule(result, 2)) # disabled multireduce variant passed 1 here
+    host = inp.numpy()
+    np.testing.assert_allclose(result.numpy(), (host+host.sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
+
+  # multireduce spec: reduce -> expand -> reduce -> expand
+  def test_reduce_expand_reduce_expand_fusion(self):
+    Tensor.manual_seed(0)
+    inp = Tensor.randn(4, 32).realize()
+    result = inp+(inp+inp.sum(-1,keepdim=True)).sum(-1, keepdim=True)
+    run_schedule(check_schedule(result, 3)) # disabled multireduce variant passed 2 here
+    host = inp.numpy()
+    inner = (host+host.sum(axis=-1,keepdims=True)).sum(axis=-1,keepdims=True)
+    np.testing.assert_allclose(result.numpy(), host+inner, atol=1e-4, rtol=1e-4)
+
+  # multireduce spec: the expanded intermediate is an output and also feeds a second reduce
+  def test_branching_reduces_and_expands_fusion(self):
+    Tensor.manual_seed(0)
+    inp = Tensor.randn(4, 32).realize()
+    shifted = inp+inp.sum(-1, keepdim=True)
+    total = shifted.sum(-1)
+    run_schedule(check_schedule([shifted, total], 3)) # disabled multireduce variant passed 2 here
+    expected_shifted = inp.numpy()+inp.numpy().sum(axis=-1,keepdims=True)
+    np.testing.assert_allclose(shifted.numpy(), expected_shifted, atol=1e-4, rtol=1e-4)
+    np.testing.assert_allclose(total.numpy(), expected_shifted.sum(axis=-1), atol=1e-4, rtol=1e-4)
+
   # multireduce spec
   def test_multireduce_fusion_simple_sequential(self):
     Tensor.manual_seed(0)