mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge branch 'triton-mlir' into ifu-231117
This commit is contained in:
@@ -1075,15 +1075,16 @@ if TORCH_HAS_FP8E5B16:
|
||||
if TORCH_HAS_FP8E4B8:
|
||||
tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz
|
||||
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
input = tl.load(input_ptr + offsets, mask=mask)
|
||||
output = input
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
def gen_input(M, N, d_type, seed, device='cuda'):
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
input = tl.load(input_ptr + offsets, mask=mask)
|
||||
output = input
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
if d_type == tl.float16:
|
||||
@@ -1246,7 +1247,8 @@ def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'c
|
||||
def test_gemm_amd_fp8_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'):
|
||||
check_type_supported(out_dtype, device)
|
||||
|
||||
if triton.language.semantic.gpu_matrix_core_version() != 3:
|
||||
backend = triton.common.backend.get_backend("hip")
|
||||
if backend.get_matrix_core_version() != 3:
|
||||
pytest.skip("fp8 data type is not available on hardware")
|
||||
|
||||
@triton.jit
|
||||
@@ -1630,7 +1632,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
('float16', 'float16'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]]
|
||||
if triton.language.semantic.gpu_matrix_core_version() == 0 else
|
||||
if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0 else
|
||||
# MFMA Test Dot tests
|
||||
[(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim)
|
||||
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
|
||||
@@ -1881,7 +1883,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
# added atol, to loose precision for float16xfloat16->float32 case
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
if torch.version.hip is not None:
|
||||
if triton.language.semantic.gpu_matrix_core_version() > 0:
|
||||
backend = triton.common.backend.get_backend("hip")
|
||||
if backend.get_matrix_core_version() > 0:
|
||||
ttgir = pgm.asm['ttgir']
|
||||
if non_k_dim == 16:
|
||||
assert "#triton_gpu.mfma<{nonKDim = 16" in ttgir
|
||||
@@ -1890,9 +1893,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
assert "#triton_gpu.mfma<{nonKDim = 32" in ttgir
|
||||
assert "#triton_gpu.mfma<{nonKDim = 16" not in ttgir
|
||||
gcn = pgm.asm['amdgcn']
|
||||
if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
|
||||
if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
|
||||
assert "v_mfma_f32_32x32x16_bf8_bf8" in gcn or "v_mfma_f32_16x16x32_bf8_bf8" in gcn
|
||||
if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8:
|
||||
if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8:
|
||||
assert "v_mfma_f32_32x32x16_fp8_fp8" in gcn or "v_mfma_f32_16x16x32_fp8_fp8" in gcn
|
||||
return
|
||||
# make sure ld/st are vectorized
|
||||
@@ -2727,7 +2730,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
|
||||
if transposeA and not transposeB:
|
||||
pytest.skip()
|
||||
|
||||
if triton.language.semantic.gpu_matrix_core_version() == 0:
|
||||
if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0:
|
||||
pytest.skip("mfma is not available on hardware")
|
||||
|
||||
# source code for following ttgir:
|
||||
@@ -2817,7 +2820,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
kernel = triton.compile(f.name, device_type="hip", cc=capabilities)
|
||||
|
||||
import triton.language.semantic as sem
|
||||
if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0:
|
||||
# if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0:
|
||||
if torch.version.hip is not None and backend.get_matrix_core_version() > 0:
|
||||
kernel[(1, 1, 1)](x_tri, y_tri, z_tri)
|
||||
np.testing.assert_allclose(z_np, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
@@ -65,8 +65,7 @@ def ttir_compute_capability_rewrite(mod, target):
|
||||
if _is_cuda(target):
|
||||
pm.add_rewrite_tensor_pointer_pass(target.capability, False)
|
||||
elif is_hip():
|
||||
capability = 90
|
||||
pm.add_rewrite_tensor_pointer_pass(capability, True)
|
||||
pm.add_rewrite_tensor_pointer_pass(target["capability"], True)
|
||||
else:
|
||||
assert(False, "unsupported target")
|
||||
pm.run(mod)
|
||||
@@ -118,14 +117,14 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
|
||||
pm.add_tritongpu_accelerate_matmul_pass(capability)
|
||||
# TODO change interface of accelerate_matmul_pass
|
||||
if is_hip():
|
||||
matrix_core_version = gpu_matrix_core_version()
|
||||
matrix_core_version = target["matrix_core_version"]
|
||||
matrix_inst_size = matrix_inst_type
|
||||
pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size)
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
if optimize_epilogue:
|
||||
pm.add_tritongpu_optimize_epilogue_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
|
||||
if num_stages == 0 and is_hip() and target["matrix_core_version"] != 0:
|
||||
pm.add_tritongpu_stream_pipeline_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
ws_enabled = False
|
||||
@@ -191,7 +190,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0):
|
||||
if _is_cuda(target):
|
||||
return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
|
||||
else:
|
||||
return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
|
||||
return translate_triton_gpu_to_llvmir(mod, target["capability"], TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
|
||||
|
||||
|
||||
# PTX translation
|
||||
@@ -360,8 +359,6 @@ def is_hip():
|
||||
raise ImportError("Triton requires PyTorch to be installed")
|
||||
return torch.version.hip is not None
|
||||
|
||||
from ..language.semantic import gpu_matrix_core_version
|
||||
|
||||
def get_cuda_capability(capability):
|
||||
if capability is None:
|
||||
device = get_current_device()
|
||||
|
||||
@@ -1188,32 +1188,6 @@ def is_hip():
|
||||
raise ImportError("Triton requires PyTorch to be installed")
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
def gpu_matrix_core_version() -> int:
|
||||
""" Determine matrix core type available on current GPU.
|
||||
|
||||
0 means no tensor cores are available
|
||||
1 corresponds to MFMA in CDNA 1 architecture
|
||||
2 corresponds to MFMA in CDNA 2 architecture
|
||||
3 corresponds to MFMA in CDNA 3 architecture
|
||||
"""
|
||||
|
||||
if not is_hip():
|
||||
return 0
|
||||
arch_info = _triton.get_arch_info()
|
||||
gfx_arch_details = re.search('amd.*', arch_info)
|
||||
if gfx_arch_details is None:
|
||||
return 0
|
||||
gfx_arch_details = gfx_arch_details.group(0).strip().split('--')
|
||||
gpu_name = gfx_arch_details[1].split(':')[0]
|
||||
if gpu_name in ['gfx908']:
|
||||
return 1
|
||||
if gpu_name in ['gfx90a']:
|
||||
return 2
|
||||
if gpu_name in ['gfx940', 'gfx941', 'gfx942']:
|
||||
return 3
|
||||
return 0
|
||||
|
||||
def mfma_supported_granularity(m, n, k) -> bool:
|
||||
# todo make this gran_type matrix element type sensitive
|
||||
for gran_type in [(32, 8), (16, 16)]:
|
||||
@@ -1226,8 +1200,8 @@ def mfma_supported_granularity(m, n, k) -> bool:
|
||||
return True
|
||||
return False
|
||||
|
||||
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
|
||||
matrix_core_version = gpu_matrix_core_version()
|
||||
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty, target) -> bool:
|
||||
matrix_core_version = target["matrix_core_version"]
|
||||
if matrix_core_version not in [1, 2, 3]:
|
||||
return False
|
||||
if not mfma_supported_granularity(M, N ,K):
|
||||
@@ -1240,10 +1214,18 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
|
||||
|
||||
def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
|
||||
# Checks for non-cuda archs
|
||||
if not _is_cuda(target):
|
||||
if is_hip():
|
||||
assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \
|
||||
(lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()) or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8()), \
|
||||
f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
|
||||
if lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8():
|
||||
assert lhs.type.scalar.is_fp8e4b8() or lhs.type.scalar.is_fp8e5b16() or lhs.type.scalar.is_fp8e5(),\
|
||||
f"Only hip fp8 or f8e5 types are accepted for both inputs of fp8"
|
||||
assert rhs.type.scalar.is_fp8e4b8() or rhs.type.scalar.is_fp8e5b16() or rhs.type.scalar.is_fp8e5(),\
|
||||
f"Only hip fp8 or f8e5 types are accepted for both inputs of fp8"
|
||||
return
|
||||
|
||||
if not _is_cuda(target):
|
||||
return
|
||||
|
||||
# Checks for cuda archs
|
||||
@@ -1287,13 +1269,18 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
|
||||
|
||||
# hip for now converts fp8 to fp16 for mixed input
|
||||
if is_hip():
|
||||
fp8_supported = gpu_matrix_core_version() == 3
|
||||
target = builder.target
|
||||
assert "matrix_core_version" in target
|
||||
fp8_supported = target["matrix_core_version"] == 3
|
||||
# gfx940 data type
|
||||
lhs_hip_fp8 = lhs.type.scalar.is_fp8e4b8() or lhs.type.scalar.is_fp8e5b16()
|
||||
rhs_hip_fp8 = rhs.type.scalar.is_fp8e4b8() or rhs.type.scalar.is_fp8e5b16()
|
||||
lhs_fp8 = lhs.type.scalar.is_fp8()
|
||||
rhs_fp8 = rhs.type.scalar.is_fp8()
|
||||
supported_fp8_dot = fp8_supported and lhs_fp8 and rhs_fp8
|
||||
if not supported_fp8_dot and lhs_fp8:
|
||||
supported_fp8_dot = fp8_supported and lhs_hip_fp8 and rhs_hip_fp8
|
||||
if (not supported_fp8_dot) and lhs_fp8:
|
||||
lhs = cast(lhs, tl.float16, builder)
|
||||
if not supported_fp8_dot and rhs_fp8:
|
||||
if (not supported_fp8_dot) and rhs_fp8:
|
||||
rhs = cast(rhs, tl.float16, builder)
|
||||
|
||||
if lhs.type.scalar.is_int():
|
||||
@@ -1316,7 +1303,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
|
||||
N = rhs.type.shape[1]
|
||||
|
||||
# Cast operands of types f16 and i8 for configurations where FMA only supported.
|
||||
if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty):
|
||||
if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty, builder.target):
|
||||
# max_num_imprecise_acc does not yet apply to hip
|
||||
if is_hip():
|
||||
max_num_imprecise_acc = 0
|
||||
@@ -1334,7 +1321,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
|
||||
ret_ty)
|
||||
return cast(ret, ret_scalar_ty, builder)
|
||||
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32,
|
||||
ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
|
||||
ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32:
|
||||
# max_num_imprecise_acc does not yet apply to hip
|
||||
if is_hip():
|
||||
max_num_imprecise_acc = 0
|
||||
|
||||
33
python/triton/third_party/hip/hip_backend.py
vendored
33
python/triton/third_party/hip/hip_backend.py
vendored
@@ -273,6 +273,30 @@ def get_amdgcn_bitcode_paths(gfx_arch: str):
|
||||
return amdgcn_bitcode_paths
|
||||
|
||||
|
||||
def gpu_matrix_core_version() -> int:
|
||||
""" Determine matrix core type available on current GPU.
|
||||
|
||||
0 means no tensor cores are available
|
||||
1 corresponds to MFMA in CDNA 1 architecture
|
||||
2 corresponds to MFMA in CDNA 2 architecture
|
||||
3 corresponds to MFMA in CDNA 3 architecture
|
||||
"""
|
||||
|
||||
arch_info = _triton.get_arch_info()
|
||||
gfx_arch_details = re.search('amd.*', arch_info)
|
||||
if gfx_arch_details is None:
|
||||
return 0
|
||||
gfx_arch_details = gfx_arch_details.group(0).strip().split('--')
|
||||
gpu_name = gfx_arch_details[1].split(':')[0]
|
||||
if gpu_name in ['gfx908']:
|
||||
return 1
|
||||
if gpu_name in ['gfx90a']:
|
||||
return 2
|
||||
if gpu_name in ['gfx940', 'gfx941', 'gfx942']:
|
||||
return 3
|
||||
return 0
|
||||
|
||||
|
||||
def get_amdgpu_arch_fulldetails():
|
||||
# print("get_amdgpu_arch_fulldetails")
|
||||
"""
|
||||
@@ -294,7 +318,11 @@ def get_amdgpu_arch_fulldetails():
|
||||
if gfx_arch is None:
|
||||
raise RuntimeError('gfx_arch is None (not specified)')
|
||||
|
||||
return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features}
|
||||
mat_core_ver = gpu_matrix_core_version()
|
||||
capability = gpu_matrix_core_version() * 100
|
||||
|
||||
return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features,\
|
||||
"capability": capability, "matrix_core_version": mat_core_ver}
|
||||
except BaseException:
|
||||
return None
|
||||
|
||||
@@ -487,3 +515,6 @@ class HIPBackend(BaseBackend):
|
||||
return _triton.get_num_warps(module)
|
||||
else:
|
||||
return _triton.get_num_warps(module)
|
||||
|
||||
def get_matrix_core_version(self):
|
||||
return gpu_matrix_core_version()
|
||||
Reference in New Issue
Block a user