[FRONTEND] add option to disable fp mul/add fusion (#2495)

By default, ptxas will enable fusion of mul/add to fma instructions. The
backend was also being configured unconditionally to enable this on
conversion from LLVM IR to PTX. This commit adds an option which can be
used to disable the FP fusion behavior in both locations.
This commit is contained in:
Stewart Hall
2023-10-14 12:23:30 -07:00
committed by GitHub
parent 3b6ec763d5
commit 29828fe491
8 changed files with 109 additions and 77 deletions

View File

@@ -3917,3 +3917,22 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
else:
torch.testing.assert_close(ref_out, C)
# -----------------------
# test enable_fp_fusion
# -----------------------
@pytest.mark.parametrize("enable_fp_fusion", [False, True])
def test_enable_fp_fusion(enable_fp_fusion):
    """Check that ``enable_fp_fusion`` controls fused multiply-add in the PTX.

    The kernel is launched with the ``enable_fp_fusion`` keyword set to the
    parametrized value, and the PTX text exposed by the returned handle is
    searched for an f32 ``mad``/``fma`` instruction. Such an instruction
    must be present exactly when fusion is enabled.
    """

    # A multiply immediately followed by an add: the sequence a backend can
    # contract into a single fused instruction.
    @triton.jit
    def mul_add(data):
        ptrs = data + tl.arange(0, 128)
        tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0)

    buf = torch.randn((128,), device='cuda', dtype=torch.float32)
    compiled = mul_add[(1,)](buf, enable_fp_fusion=enable_fp_fusion)

    # Matches e.g. "fma.rn.f32" or "mad.rz.ftz.f32": opcode, rounding mode,
    # optional flush-to-zero modifier, f32 operand type.
    fused_op = r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32'
    ptx = compiled.asm["ptx"]

    if enable_fp_fusion:
        assert re.search(fused_op, ptx) is not None
    else:
        assert re.search(fused_op, ptx) is None