Merge branch 'triton-mlir' into ifu-231117

This commit is contained in:
jayfurmanek
2023-11-27 07:44:04 -06:00
committed by GitHub
5 changed files with 265 additions and 168 deletions

View File

@@ -65,8 +65,7 @@ def ttir_compute_capability_rewrite(mod, target):
if _is_cuda(target):
pm.add_rewrite_tensor_pointer_pass(target.capability, False)
elif is_hip():
capability = 90
pm.add_rewrite_tensor_pointer_pass(capability, True)
pm.add_rewrite_tensor_pointer_pass(target["capability"], True)
else:
assert(False, "unsupported target")
pm.run(mod)
@@ -118,14 +117,14 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
pm.add_tritongpu_accelerate_matmul_pass(capability)
# TODO change interface of accelerate_matmul_pass
if is_hip():
matrix_core_version = gpu_matrix_core_version()
matrix_core_version = target["matrix_core_version"]
matrix_inst_size = matrix_inst_type
pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size)
pm.add_tritongpu_remove_layout_conversions_pass()
if optimize_epilogue:
pm.add_tritongpu_optimize_epilogue_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
if num_stages == 0 and is_hip() and target["matrix_core_version"] != 0:
pm.add_tritongpu_stream_pipeline_pass()
pm.add_canonicalizer_pass()
ws_enabled = False
@@ -191,7 +190,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0):
if _is_cuda(target):
return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
else:
return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
return translate_triton_gpu_to_llvmir(mod, target["capability"], TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
# PTX translation
@@ -360,8 +359,6 @@ def is_hip():
raise ImportError("Triton requires PyTorch to be installed")
return torch.version.hip is not None
from ..language.semantic import gpu_matrix_core_version
def get_cuda_capability(capability):
if capability is None:
device = get_current_device()

View File

@@ -1188,32 +1188,6 @@ def is_hip():
raise ImportError("Triton requires PyTorch to be installed")
return torch.version.hip is not None
def gpu_matrix_core_version() -> int:
    """Report the MFMA (matrix core) generation of the current GPU.

    Returns:
        0 if no tensor cores are available (non-HIP build, or an
        unrecognized architecture string),
        1 for MFMA in the CDNA 1 architecture (gfx908),
        2 for MFMA in the CDNA 2 architecture (gfx90a),
        3 for MFMA in the CDNA 3 architecture (gfx940/gfx941/gfx942).
    """
    if not is_hip():
        return 0
    # The arch info string embeds an "amd--<gpu>[:features]" triple;
    # pull out the gpu name between the '--' separator and the first ':'.
    match = re.search('amd.*', _triton.get_arch_info())
    if match is None:
        return 0
    fields = match.group(0).strip().split('--')
    gpu_name = fields[1].split(':')[0]
    # Map known gfx architectures to their MFMA generation.
    mfma_generation = {
        'gfx908': 1,
        'gfx90a': 2,
        'gfx940': 3,
        'gfx941': 3,
        'gfx942': 3,
    }
    return mfma_generation.get(gpu_name, 0)
def mfma_supported_granularity(m, n, k) -> bool:
# todo make this gran_type matrix element type sensitive
for gran_type in [(32, 8), (16, 16)]:
@@ -1226,8 +1200,8 @@ def mfma_supported_granularity(m, n, k) -> bool:
return True
return False
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
matrix_core_version = gpu_matrix_core_version()
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty, target) -> bool:
matrix_core_version = target["matrix_core_version"]
if matrix_core_version not in [1, 2, 3]:
return False
if not mfma_supported_granularity(M, N ,K):
@@ -1240,10 +1214,18 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
# Checks for non-cuda archs
if not _is_cuda(target):
if is_hip():
assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \
(lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()) or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8()), \
f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
if lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8():
assert lhs.type.scalar.is_fp8e4b8() or lhs.type.scalar.is_fp8e5b16() or lhs.type.scalar.is_fp8e5(),\
f"Only hip fp8 or f8e5 types are accepted for both inputs of fp8"
assert rhs.type.scalar.is_fp8e4b8() or rhs.type.scalar.is_fp8e5b16() or rhs.type.scalar.is_fp8e5(),\
f"Only hip fp8 or f8e5 types are accepted for both inputs of fp8"
return
if not _is_cuda(target):
return
# Checks for cuda archs
@@ -1287,13 +1269,18 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
# hip for now converts fp8 to fp16 for mixed input
if is_hip():
fp8_supported = gpu_matrix_core_version() == 3
target = builder.target
assert "matrix_core_version" in target
fp8_supported = target["matrix_core_version"] == 3
# gfx940 data type
lhs_hip_fp8 = lhs.type.scalar.is_fp8e4b8() or lhs.type.scalar.is_fp8e5b16()
rhs_hip_fp8 = rhs.type.scalar.is_fp8e4b8() or rhs.type.scalar.is_fp8e5b16()
lhs_fp8 = lhs.type.scalar.is_fp8()
rhs_fp8 = rhs.type.scalar.is_fp8()
supported_fp8_dot = fp8_supported and lhs_fp8 and rhs_fp8
if not supported_fp8_dot and lhs_fp8:
supported_fp8_dot = fp8_supported and lhs_hip_fp8 and rhs_hip_fp8
if (not supported_fp8_dot) and lhs_fp8:
lhs = cast(lhs, tl.float16, builder)
if not supported_fp8_dot and rhs_fp8:
if (not supported_fp8_dot) and rhs_fp8:
rhs = cast(rhs, tl.float16, builder)
if lhs.type.scalar.is_int():
@@ -1316,7 +1303,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
N = rhs.type.shape[1]
# Cast operands of types f16 and i8 for configurations where FMA only supported.
if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty):
if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty, builder.target):
# max_num_imprecise_acc does not yet apply to hip
if is_hip():
max_num_imprecise_acc = 0
@@ -1334,7 +1321,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
ret_ty)
return cast(ret, ret_scalar_ty, builder)
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32,
ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32:
# max_num_imprecise_acc does not yet apply to hip
if is_hip():
max_num_imprecise_acc = 0

View File

@@ -273,6 +273,30 @@ def get_amdgcn_bitcode_paths(gfx_arch: str):
return amdgcn_bitcode_paths
def gpu_matrix_core_version() -> int:
    """Determine the matrix core (MFMA) type available on the current GPU.

    Returns:
        0 when no tensor cores are available (architecture string absent
        or not recognized),
        1 for MFMA in CDNA 1 (gfx908),
        2 for MFMA in CDNA 2 (gfx90a),
        3 for MFMA in CDNA 3 (gfx940, gfx941, gfx942).
    """
    # Arch info looks like "...amd--<gpu>[:features]..."; the gpu name sits
    # after the '--' separator, before any ':' feature suffix.
    arch_match = re.search('amd.*', _triton.get_arch_info())
    if arch_match is None:
        return 0
    pieces = arch_match.group(0).strip().split('--')
    name = pieces[1].split(':')[0]
    if name == 'gfx908':
        return 1
    elif name == 'gfx90a':
        return 2
    elif name in ('gfx940', 'gfx941', 'gfx942'):
        return 3
    else:
        return 0
def get_amdgpu_arch_fulldetails():
# print("get_amdgpu_arch_fulldetails")
"""
@@ -294,7 +318,11 @@ def get_amdgpu_arch_fulldetails():
if gfx_arch is None:
raise RuntimeError('gfx_arch is None (not specified)')
return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features}
mat_core_ver = gpu_matrix_core_version()
capability = gpu_matrix_core_version() * 100
return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features,\
"capability": capability, "matrix_core_version": mat_core_ver}
except BaseException:
return None
@@ -487,3 +515,6 @@ class HIPBackend(BaseBackend):
return _triton.get_num_warps(module)
else:
return _triton.get_num_warps(module)
def get_matrix_core_version(self):
    # Backend accessor: expose the module-level gpu_matrix_core_version()
    # probe (0 = no tensor cores; 1/2/3 = CDNA 1/2/3 MFMA generation).
    return gpu_matrix_core_version()