[MFMA] FP8 and BF8 support (#355)

* [MFMA] FP8 and BF8 support

This PR adds support for fp8 and bf8 in the AccelerateMatmul pass and
introduces generation of float8 MFMA instructions in the TTG-to-LLVM conversion.
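
For context, a minimal sketch (not part of this diff) of the pattern the pass now handles: cast both dot operands to an fp8 element type and let tl.dot lower to a float8 MFMA variant. The kernel name and tile shape below are hypothetical; only tl.float8e5, .to(), and tl.dot(..., out_dtype=...) appear in the code under test:

import triton
import triton.language as tl

@triton.jit
def fp8_dot_kernel(X, Y, Z, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # square BLOCK x BLOCK tiles, row-major
    x = tl.load(X + offs[:, None] * BLOCK + offs[None, :])
    y = tl.load(Y + offs[:, None] * BLOCK + offs[None, :])
    # cast the operands to fp8 before the dot so AccelerateMatmul can
    # select a float8 MFMA instruction
    x = x.to(tl.float8e5)
    y = y.to(tl.float8e5)
    z = tl.dot(x, y, out_dtype=tl.float32)
    tl.store(Z + offs[:, None] * BLOCK + offs[None, :], z)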

* add tests

* fix tests

* review fix: variable naming and dot operand promotion

* address review comments

---------

Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
Author: Alexander Efimov
Committed: 2023-10-25 20:27:10 +02:00 (via GitHub)
Commit: 5a86b46bb1 (parent: 8547694665)
8 changed files with 199 additions and 30 deletions


@@ -1492,6 +1492,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
for allow_tf32 in [True, False]
for in_dtype, out_dtype in [('float16', 'float16'),
('bfloat16', 'float32'),
('float8e5m2fnuz', 'float32'),
('float8e4m3fnuz', 'float32'),
('float16', 'float32'),
('float32', 'float32')]
for non_k_dim in [0, 16, 32]
@@ -1531,6 +1533,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
capability = torch.cuda.get_device_capability()
if torch.version.hip is not None:
# TODO: consider re-enabling these tests when fp8 casting is fixed
if M == 16 and N == 16 and K == 16 and "float8" in in_dtype:
pytest.skip("triton does not generate MFMA instructions for the given block size")
# set capability to a large number to jump over the checks below;
# the checks are not relevant to AMD GPUs but are kept for a smaller diff
# between test_core.py and test_core_amd.py
if (M, N, K) == (128, 256, 32):
@@ -1563,6 +1569,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
Y, stride_yk, stride_yn,
W, stride_wn, stride_wl,
Z, stride_zm, stride_zn,
in_dtype: tl.constexpr,
out_dtype: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
@@ -1579,6 +1586,11 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
x = tl.load(Xs)
y = tl.load(Ys)
if in_dtype is tl.float8e4b15 or in_dtype is tl.float8e5:
# TODO change types when they are available
# if in_dtype is tl.float8e5b16 or in_dtype is tl.float8e4b8:
x = x.to(in_dtype)
y = y.to(in_dtype)
z = tl.dot(x, y, allow_tf32=ALLOW_TF32, out_dtype=out_dtype)
if ADD_MATRIX:
z += tl.load(Zs)
@@ -1599,6 +1611,27 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
z = tl.dot(z.to(w.dtype), w, out_dtype=out_dtype)
tl.store(Zs, z)
# input
if in_dtype == "int8":
effective_in_dtype = tl.int8
elif in_dtype == "float32":
effective_in_dtype = tl.float32
elif in_dtype == "float16":
effective_in_dtype = tl.float16
elif in_dtype == "bfloat16":
effective_in_dtype = tl.bfloat16
elif in_dtype == "float8e5m2fnuz":
# TODO change types when they are available
effective_in_dtype = tl.float8e5
# effective_in_dtype = tl.float8e5b16
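# numpy cannot generate fp8 data, so keep the host buffers in float32;
# the kernel casts them to fp8 before the dot (see x.to(in_dtype) above)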
in_dtype = "float32"
elif in_dtype == "float8e4m3fnuz":
# TODO change types when they are available
effective_in_dtype = tl.float8e4b15
# effective_in_dtype = tl.float8e4b8
in_dtype = "float32"
else:
assert False, "unexpected in dtype"
rs = RandomState(17)
if col_a:
x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T
@@ -1616,6 +1649,13 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
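# fp8 has no numpy dtype, so x and y above stay float32; the masks keep the
# sign bit plus only as many high exponent bits (5 for e5m2, 4 for e4m3) and
# explicit mantissa bits (2 for e5m2, 3 for e4m3) as the target fp8 format
# carries, so the masked values round-trip through fp8 without extra rounding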
if effective_in_dtype.is_fp8():
if effective_in_dtype.is_fp8e5():
mask = 0b111111000110 << 20
else:
mask = 0b111110000111 << 20
x = (x.view('uint32') & np.uint32(mask)).view('float32')
y = (y.view('uint32') & np.uint32(mask)).view('float32')
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
@@ -1647,6 +1687,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
effective_in_dtype,
out_dtype,
COL_A=col_a, COL_B=col_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
@@ -1692,8 +1733,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
# added atol to loosen precision for the float16 x float16 -> float32 case
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
if torch.version.hip is not None:
if triton.language.semantic.gpu_matrix_core_version() > 0:
ttgir = pgm.asm['ttgir']
if non_k_dim == 16:
assert "#triton_gpu.mfma<{nonKDim = 16" in ttgir
@@ -1701,6 +1741,11 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
elif non_k_dim == 32:
assert "#triton_gpu.mfma<{nonKDim = 32" in ttgir
assert "#triton_gpu.mfma<{nonKDim = 16" not in ttgir
gcn = pgm.asm['amdgcn']
if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
assert "v_mfma_f32_32x32x16_bf8_bf8" in gcn or "v_mfma_f32_16x16x32_bf8_bf8" in gcn
if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8:
assert "v_mfma_f32_32x32x16_fp8_fp8" in gcn or "v_mfma_f32_16x16x32_fp8_fp8" in gcn
return
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
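
As a usage note (not part of this diff), the same asm dict consulted above can be dumped when debugging locally; pgm is the handle returned by launching a @triton.jit kernel, as in the test:

# hypothetical debugging snippet; `pgm` is a launched-kernel handle as above
print(pgm.asm['ttgir'])   # look for #triton_gpu.mfma<{nonKDim = 16 or 32
print(pgm.asm['amdgcn'])  # look for v_mfma_f32_*_fp8_fp8 / *_bf8_bf8 on fp8 inputs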