Support WMMA layout in TritonAMDGPUAccelerateMatmulPass

-Introduce WmmaEncodingAttr for WMMA output
-Introduce BlockedToWMMA rewrite pattern in TritonAMDGPUAccelerateMatmulPass
-Provide a flag to check if wmma instructions are supported by target

Signed-off-by: joviliast <iveselov.nn@gmail.com>
This commit is contained in:
joviliast
2023-12-11 20:53:06 +02:00
committed by Lixun Zhang
parent b7a412d82a
commit af15da2f84
9 changed files with 258 additions and 23 deletions

View File

@@ -1871,9 +1871,9 @@ void init_triton_ir(py::module &&m) {
mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
})
.def("add_tritonamdgpu_accelerate_matmul_pass",
[](mlir::PassManager &self, int tensorCoreVersion, int instrSize) {
[](mlir::PassManager &self, const std::string archGenName, int instrSize) {
self.addPass(mlir::createTritonAMDGPUAccelerateMatmulPass(
tensorCoreVersion, instrSize));
archGenName, instrSize));
})
.def("add_tritongpu_optimize_dot_operands_pass",
[](mlir::PassManager &self) {

View File

@@ -117,9 +117,9 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
pm.add_tritongpu_accelerate_matmul_pass(capability)
# TODO change interface of accelerate_matmul_pass
if is_hip():
matrix_core_version = target["matrix_core_version"]
gfx_arch = target["gfx_arch"]
matrix_inst_size = matrix_inst_type
pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size)
pm.add_tritonamdgpu_accelerate_matmul_pass(gfx_arch, matrix_inst_size)
pm.add_tritongpu_remove_layout_conversions_pass()
if optimize_epilogue:
pm.add_tritongpu_optimize_epilogue_pass()