mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] add option to disable fp mul/add fusion (#2495)
By default, ptxas will enable fusion of mul/add to fma instructions. The backend was also being configured unconditionally to enable this on conversion from LLVM IR to PTX. This commit adds an option which can be used to disable the FP fusion behavior in both locations.
This commit is contained in:
@@ -3917,3 +3917,22 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
|
||||
torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
torch.testing.assert_close(ref_out, C)
|
||||
|
||||
# -----------------------
|
||||
# test enable_fp_fusion
|
||||
# -----------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_fp_fusion", [False, True])
|
||||
def test_enable_fp_fusion(enable_fp_fusion):
|
||||
# Sequential multiply add can be fused by backend
|
||||
@triton.jit
|
||||
def mul_add(data):
|
||||
ptrs = data + tl.arange(0, 128)
|
||||
tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0)
|
||||
|
||||
data = torch.randn((128,), device='cuda', dtype=torch.float32)
|
||||
h = mul_add[(1,)](data, enable_fp_fusion=enable_fp_fusion)
|
||||
|
||||
found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None
|
||||
assert found_fma == enable_fp_fusion
|
||||
|
||||
Reference in New Issue
Block a user