From cc1797673efc0f066ad687cc58d248470394cedf Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 29 Apr 2024 19:32:23 +0300
Subject: [PATCH] all fusion opportunities (#4348)

---
 test/test_schedule.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/test/test_schedule.py b/test/test_schedule.py
index 10bebcbded..f35f967c7d 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -530,6 +530,32 @@ class TestSchedule(unittest.TestCase):
     out1 = out0[0] + Tensor.empty(1, )
     check_schedule([r, out0, out1], 3)
 
+  def test_softmax_fusion(self):
+    out = Tensor.empty(4, 12, 64, 64).softmax()
+    check_schedule(out, 3)
+
+  def test_layernorm_onelayer_fusion(self):
+    layer = nn.LayerNorm([10, 10])
+    x = Tensor.empty(20, 5, 10, 10)
+    check_schedule(layer(x), 3)
+
+  def test_scaled_dot_product_attention_fusion(self):
+    x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
+    out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m)
+    check_schedule(out, 5)
+
+  def test_scaled_dot_product_attention_causal_fusion(self):
+    x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
+    out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m, is_causal=True)
+    check_schedule(out, 7)
+
+  def test_adam_step_fusion(self):
+    x = Tensor.empty(4, 64, 768)
+    layer = nn.Linear(768, 768*4)
+    opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
+    layer(x).relu().sum().backward()
+    check_schedule(opt.schedule_step(), 14)
+
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
     x = Tensor.ones(4).contiguous().realize()
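
---

Notes (not part of the patch): the new tests can be selected by name with
  python -m pytest test/test_schedule.py -k fusion

For context, here is a minimal sketch of what the kernel-count assertion is
checking, done by hand rather than through the check_schedule helper. This is
an illustration under assumptions, not the helper's actual implementation: it
assumes a tinygrad version where Tensor.schedule() exists, and the raw
schedule may also contain buffer-setup items (e.g. for Tensor.empty) that
check_schedule filters out before counting.

  from tinygrad import Tensor

  out = Tensor.empty(4, 12, 64, 64).softmax()
  # Tensor.schedule() returns the list of ScheduleItems queued to realize out,
  # roughly one per kernel (plus any buffer-setup items, filtered by the tests).
  sched = out.schedule()
  # the test above expects softmax to fuse into 3 compute kernels:
  # a max reduce, a sum-of-exp reduce, and the final elementwise exp/div
  print(len(sched))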