mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[MFMA] MI200 bfloat16 support (#294)
This PR enables bfloat16 support in the MFMA dot operation on MI200, using the mfma_f32_32x32x8bf16_1k instruction.
This commit is contained in:
@@ -1233,6 +1233,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
|
||||
for allow_tf32 in [True, False]
|
||||
for in_dtype, out_dtype in [('float16', 'float16'),
|
||||
('bfloat16', 'float32'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]
|
||||
if not (allow_tf32 and (in_dtype in ['float16']))] +
|
||||
@@ -1264,7 +1265,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
for allow_tf32 in [False, True]
|
||||
for col_a in [True, False]
|
||||
for col_b in [True, False]
|
||||
for in_dtype in ['int8', 'float16', 'float32']
|
||||
for in_dtype in ['int8', 'bfloat16', 'float16', 'float32']
|
||||
for out_dtype in [None]])
|
||||
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
@@ -1354,6 +1355,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
x_tri = to_triton(x, device=device)
|
||||
y_tri = to_triton(y, device=device)
|
||||
w_tri = to_triton(w, device=device)
|
||||
if in_dtype == 'bfloat16':
|
||||
x_tri = x_tri.to(torch.bfloat16)
|
||||
y_tri = y_tri.to(torch.bfloat16)
|
||||
w_tri = w_tri.to(torch.bfloat16)
|
||||
# triton result
|
||||
if out_dtype == 'int8':
|
||||
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
|
||||
@@ -1414,6 +1419,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
elif out_dtype == tl.float16:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
|
||||
elif in_dtype == 'bfloat16':
|
||||
# added atol, to loosen precision for the bfloat16 x bfloat16 -> float32 case
|
||||
# bfloat16 has less fraction bits than float16
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
|
||||
else:
|
||||
# added atol, to loosen precision for the float16 x float16 -> float32 case
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user