mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
use hw for fp8 type conversion (#386)
* Use hardware instructions for type conversion between fp8 and fp32.
* Move gpu_matrix_core_version from semantic.py to hip_backend.py.
---------
Co-authored-by: Aleksandr Efimov <efimov.alexander@gmail.com>
This commit is contained in:
@@ -1075,15 +1075,16 @@ if TORCH_HAS_FP8E5B16:
|
||||
if TORCH_HAS_FP8E4B8:
|
||||
tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz
|
||||
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
input = tl.load(input_ptr + offsets, mask=mask)
|
||||
output = input
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
def gen_input(M, N, d_type, seed, device='cuda'):
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
input = tl.load(input_ptr + offsets, mask=mask)
|
||||
output = input
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
if d_type == tl.float16:
|
||||
@@ -1246,7 +1247,8 @@ def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'c
|
||||
def test_gemm_amd_fp8_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'):
|
||||
check_type_supported(out_dtype, device)
|
||||
|
||||
if triton.language.semantic.gpu_matrix_core_version() != 3:
|
||||
backend = triton.common.backend.get_backend("hip")
|
||||
if backend.get_matrix_core_version() != 3:
|
||||
pytest.skip("fp8 data type is not available on hardware")
|
||||
|
||||
@triton.jit
|
||||
@@ -1630,7 +1632,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
('float16', 'float16'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]]
|
||||
if triton.language.semantic.gpu_matrix_core_version() == 0 else
|
||||
if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0 else
|
||||
# MFMA Test Dot tests
|
||||
[(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim)
|
||||
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
|
||||
@@ -1881,7 +1883,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
# added atol to loosen precision for the float16 x float16 -> float32 case
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
if torch.version.hip is not None:
|
||||
if triton.language.semantic.gpu_matrix_core_version() > 0:
|
||||
backend = triton.common.backend.get_backend("hip")
|
||||
if backend.get_matrix_core_version() > 0:
|
||||
ttgir = pgm.asm['ttgir']
|
||||
if non_k_dim == 16:
|
||||
assert "#triton_gpu.mfma<{nonKDim = 16" in ttgir
|
||||
@@ -1890,9 +1893,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
assert "#triton_gpu.mfma<{nonKDim = 32" in ttgir
|
||||
assert "#triton_gpu.mfma<{nonKDim = 16" not in ttgir
|
||||
gcn = pgm.asm['amdgcn']
|
||||
if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
|
||||
if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
|
||||
assert "v_mfma_f32_32x32x16_bf8_bf8" in gcn or "v_mfma_f32_16x16x32_bf8_bf8" in gcn
|
||||
if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8:
|
||||
if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8:
|
||||
assert "v_mfma_f32_32x32x16_fp8_fp8" in gcn or "v_mfma_f32_16x16x32_fp8_fp8" in gcn
|
||||
return
|
||||
# make sure ld/st are vectorized
|
||||
@@ -2727,7 +2730,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
|
||||
if transposeA and not transposeB:
|
||||
pytest.skip()
|
||||
|
||||
if triton.language.semantic.gpu_matrix_core_version() == 0:
|
||||
if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0:
|
||||
pytest.skip("mfma is not available on hardware")
|
||||
|
||||
# source code for following ttgir:
|
||||
@@ -2817,7 +2820,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
kernel = triton.compile(f.name, device_type="hip", cc=capabilities)
|
||||
|
||||
import triton.language.semantic as sem
|
||||
if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0:
|
||||
# if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0:
|
||||
if torch.version.hip is not None and backend.get_matrix_core_version() > 0:
|
||||
kernel[(1, 1, 1)](x_tri, y_tri, z_tri)
|
||||
np.testing.assert_allclose(z_np, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user