[MFMA] FP8 and BF8 support (#355)
* [MFMA] FP8 and BF8 support: this PR adds support for fp8 and bf8 in the AccelerateMatmul pass and introduces generation of float8 MFMA instructions in the TritonGPU-to-LLVM conversion.
* add tests
* fix tests
* review fix: fix variable naming and dot operand promotion
* review comments fixes
---------
Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
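For orientation, here is a minimal sketch (not part of this commit) of the kind of kernel the new path targets: fp32 tiles are loaded, downcast to an fp8 element type, and fed to tl.dot, which AccelerateMatmul can then lower to MFMA on an AMD matrix-core GPU. The kernel name, the 32x32x32 shape, and the choice of tl.float8e5 are illustrative assumptions, mirroring the test kernel in the diff below.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fp8_dot_kernel(X, Y, Z,
                   M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    off_m = tl.arange(0, M)
    off_n = tl.arange(0, N)
    off_k = tl.arange(0, K)
    # load fp32 tiles from row-major buffers
    x = tl.load(X + off_m[:, None] * K + off_k[None, :])
    y = tl.load(Y + off_k[:, None] * N + off_n[None, :])
    # downcast to an fp8 type before the dot, as the test kernel below does
    x = x.to(tl.float8e5)
    y = y.to(tl.float8e5)
    z = tl.dot(x, y, out_dtype=tl.float32)
    tl.store(Z + off_m[:, None] * N + off_n[None, :], z)


x = torch.randn((32, 32), device='cuda')
y = torch.randn((32, 32), device='cuda')
z = torch.empty((32, 32), device='cuda')
fp8_dot_kernel[(1,)](x, y, z, M=32, N=32, K=32)
```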
@@ -1492,6 +1492,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
                          for allow_tf32 in [True, False]
                          for in_dtype, out_dtype in [('float16', 'float16'),
                                                      ('bfloat16', 'float32'),
+                                                     ('float8e5m2fnuz', 'float32'),
+                                                     ('float8e4m3fnuz', 'float32'),
                                                      ('float16', 'float32'),
                                                      ('float32', 'float32')]
                          for non_k_dim in [0, 16, 32]
@@ -1531,6 +1533,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
     capability = torch.cuda.get_device_capability()

     if torch.version.hip is not None:
+        # TODO consider reenabling this tests when fp8 casing is fixed
+        if M == 16 and N == 16 and K == 16 and "float8" in in_dtype:
+            pytest.skip("triton do not generate MFMA instructions for given block size")
+
         # set capability to large number to jump over check below
         # check are not relevant to amd gpu, left them for smaller diff between test_core.py and test_core_amd.py tests
         if (M, N, K) == (128, 256, 32):
@@ -1563,6 +1569,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
                Y, stride_yk, stride_yn,
                W, stride_wn, stride_wl,
                Z, stride_zm, stride_zn,
+               in_dtype: tl.constexpr,
                out_dtype: tl.constexpr,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
@@ -1579,6 +1586,11 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
         x = tl.load(Xs)
         y = tl.load(Ys)
+        if in_dtype is tl.float8e4b15 or in_dtype is tl.float8e5:
+            # TODO change types when they are available
+            # if in_dtype is tl.float8e5b16 or in_dtype is tl.float8e4b8:
+            x = x.to(in_dtype)
+            y = y.to(in_dtype)
         z = tl.dot(x, y, allow_tf32=ALLOW_TF32, out_dtype=out_dtype)
         if ADD_MATRIX:
             z += tl.load(Zs)
@@ -1599,6 +1611,27 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
             z = tl.dot(z.to(w.dtype), w, out_dtype=out_dtype)
         tl.store(Zs, z)
     # input
+    if in_dtype == "int8":
+        effective_in_dtype = tl.int8
+    elif in_dtype == "float32":
+        effective_in_dtype = tl.float32
+    elif in_dtype == "float16":
+        effective_in_dtype = tl.float16
+    elif in_dtype == "bfloat16":
+        effective_in_dtype = tl.bfloat16
+    elif in_dtype == "float8e5m2fnuz":
+        # TODO change types when they are available
+        effective_in_dtype = tl.float8e5
+        # effective_in_dtype = tl.float8e5b16
+        in_dtype = "float32"
+    elif in_dtype == "float8e4m3fnuz":
+        # TODO change types when they are available
+        effective_in_dtype = tl.float8e4b15
+        # effective_in_dtype = tl.float8e4b8
+        in_dtype = "float32"
+    else:
+        assert("unexpected in dtype")
+
     rs = RandomState(17)
     if col_a:
         x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T
@@ -1616,6 +1649,13 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
         x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
         y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
         w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
+    if effective_in_dtype.is_fp8():
+        if effective_in_dtype.is_fp8e5():
+            mask = 0b111111000110 << 20
+        else:
+            mask = 0b111110000111 << 20
+        x = (x.view('uint32') & np.uint32(mask)).view('float32')
+        y = (y.view('uint32') & np.uint32(mask)).view('float32')
     x_tri = to_triton(x, device=device)
     y_tri = to_triton(y, device=device)
     w_tri = to_triton(w, device=device)
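As an aside (not part of the commit), the two new masks can be read against the IEEE-754 fp32 layout of 1 sign, 8 exponent, and 23 mantissa bits: a 12-bit pattern shifted left by 20 keeps the sign bit, the top few exponent bits, and the top 2 or 3 mantissa bits, which matches the e5m2 and e4m3 field widths; presumably this is so the random fp32 inputs survive the fp8 downcast in the kernel without rounding differences. A small sketch decoding the fields:

```python
# Illustrative only: decode the two masks against the fp32 bit layout
# (1 sign bit, 8 exponent bits, 23 mantissa bits).
masks = {
    "fp8e5 (e5m2-style)": 0b111111000110 << 20,
    "fp8e4 (e4m3-style)": 0b111110000111 << 20,
}
for name, mask in masks.items():
    sign = (mask >> 31) & 0x1
    exponent = (mask >> 23) & 0xFF
    mantissa = mask & 0x7FFFFF
    # e5 keeps 5 exponent bits and 2 mantissa bits, e4 keeps 4 and 3
    print(f"{name}: sign={sign:b} exponent={exponent:08b} mantissa={mantissa:023b}")
```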
@@ -1647,6 +1687,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
                     y_tri, y_tri.stride(0), y_tri.stride(1),
                     w_tri, w_tri.stride(0), w_tri.stride(1),
                     z_tri, z_tri.stride(0), z_tri.stride(1),
+                    effective_in_dtype,
                     out_dtype,
                     COL_A=col_a, COL_B=col_b,
                     BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
@@ -1692,8 +1733,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
     # added atol, to loose precision for float16xfloat16->float32 case
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
     if torch.version.hip is not None:
-        import triton.language.semantic as sem
-        if sem.gpu_matrix_core_version() > 0:
+        if triton.language.semantic.gpu_matrix_core_version() > 0:
             ttgir = pgm.asm['ttgir']
             if non_k_dim == 16:
                 assert "#triton_gpu.mfma<{nonKDim = 16" in ttgir
@@ -1701,6 +1741,11 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
             elif non_k_dim == 32:
                 assert "#triton_gpu.mfma<{nonKDim = 32" in ttgir
                 assert "#triton_gpu.mfma<{nonKDim = 16" not in ttgir
+            gcn = pgm.asm['amdgcn']
+            if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
+                assert "v_mfma_f32_32x32x16_bf8_bf8" in gcn or "v_mfma_f32_16x16x32_bf8_bf8" in gcn
+            if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8:
+                assert "v_mfma_f32_32x32x16_fp8_fp8" in gcn or "v_mfma_f32_16x16x32_fp8_fp8" in gcn
             return
     # make sure ld/st are vectorized
     ptx = pgm.asm['ptx']
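For completeness, a hedged sketch of how the same inspection can be done interactively, reusing the illustrative fp8_dot_kernel from the sketch above (an assumption, not code from this commit): launch once, then search the compiled AMDGCN for MFMA mnemonics.

```python
# A kernel launch returns a handle whose .asm dict holds the compiled
# artifacts; the 'amdgcn' entry is present only when running on a ROCm device.
pgm = fp8_dot_kernel[(1,)](x, y, z, M=32, N=32, K=32)
gcn = pgm.asm['amdgcn']
# List the MFMA instructions actually emitted, e.g. v_mfma_f32_32x32x16_fp8_fp8.
print([line.strip() for line in gcn.splitlines() if "v_mfma" in line])
```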