all fusion opportunities (#4348)

This commit is contained in:
qazal
2024-04-29 19:32:23 +03:00
committed by GitHub
parent f363f39e83
commit cc1797673e

View File

@@ -530,6 +530,32 @@ class TestSchedule(unittest.TestCase):
out1 = out0[0] + Tensor.empty(1, )
check_schedule([r, out0, out1], 3)
def test_softmax_fusion(self):
out = Tensor.empty(4, 12, 64, 64).softmax()
check_schedule(out, 3)
def test_layernorm_onelayer_fusion(self):
layer = nn.LayerNorm([10, 10])
x = Tensor.empty(20, 5, 10, 10)
check_schedule(layer(x), 3)
def test_scaled_dot_product_attention_fusion(self):
x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m)
check_schedule(out, 5)
def test_scaled_dot_product_attention_causal_fusion(self):
x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m, is_causal=True)
check_schedule(out, 7)
def test_adam_step_fusion(self):
x = Tensor.empty(4, 64, 768)
layer = nn.Linear(768, 768*4)
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
layer(x).relu().sum().backward()
check_schedule(opt.schedule_step(), 14)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_prefer_half_buffer(self):
x = Tensor.ones(4).contiguous().realize()