[MFMA] MI200 bfloat16 support (#294)

This PR enables bfloat16 support in MFMA dot on MI200.
Used mfma_f32_32x32x8bf16_1k instruction.
This commit is contained in:
Alexander Efimov
2023-08-18 14:28:18 +02:00
committed by GitHub
parent f7cf2c032b
commit 23979098c8
4 changed files with 30 additions and 7 deletions

View File

@@ -1233,6 +1233,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
for allow_tf32 in [True, False]
for in_dtype, out_dtype in [('float16', 'float16'),
('bfloat16', 'float32'),
('float16', 'float32'),
('float32', 'float32')]
if not (allow_tf32 and (in_dtype in ['float16']))] +
@@ -1264,7 +1265,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
for allow_tf32 in [False, True]
for col_a in [True, False]
for col_b in [True, False]
for in_dtype in ['int8', 'float16', 'float32']
for in_dtype in ['int8', 'bfloat16', 'float16', 'float32']
for out_dtype in [None]])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, device='cuda'):
capability = torch.cuda.get_device_capability()
@@ -1354,6 +1355,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
if in_dtype == 'bfloat16':
x_tri = x_tri.to(torch.bfloat16)
y_tri = y_tri.to(torch.bfloat16)
w_tri = w_tri.to(torch.bfloat16)
# triton result
if out_dtype == 'int8':
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
@@ -1414,6 +1419,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
elif out_dtype == tl.float16:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
elif in_dtype == 'bfloat16':
# added atol, to loose precision for bfloat16xbfloat16->float32 case
# bfloat16 has less fraction bits than float16
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
else:
# added atol, to loose precision for float16xfloat16->float32 case
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)