Support WMMA layout in TritonAMDGPUAccelerateMatmulPass
- Introduce WmmaEncodingAttr for WMMA output
- Introduce BlockedToWMMA rewrite pattern in TritonAMDGPUAccelerateMatmulPass
- Provide a flag to check if WMMA instructions are supported by the target

Signed-off-by: joviliast <iveselov.nn@gmail.com>
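The support flag itself is not visible in this excerpt. As a minimal sketch of what an architecture-name-based WMMA check can look like, assuming WMMA is limited to the RDNA3 (gfx11) generation; the helper name supports_wmma is hypothetical, not taken from the patch:

    def supports_wmma(gfx_arch: str) -> bool:
        # Hypothetical helper: WMMA instructions exist on RDNA3 (gfx11xx)
        # GPUs, while CDNA parts (gfx90a, gfx94x, ...) use MFMA instead.
        return gfx_arch.startswith("gfx11")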
@@ -1871,9 +1871,9 @@ void init_triton_ir(py::module &&m) {
                 mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
           })
       .def("add_tritonamdgpu_accelerate_matmul_pass",
-           [](mlir::PassManager &self, int tensorCoreVersion, int instrSize) {
+           [](mlir::PassManager &self, const std::string archGenName, int instrSize) {
             self.addPass(mlir::createTritonAMDGPUAccelerateMatmulPass(
-                tensorCoreVersion, instrSize));
+                archGenName, instrSize));
           })
       .def("add_tritongpu_optimize_dot_operands_pass",
           [](mlir::PassManager &self) {
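With this binding change the pass is keyed on the architecture generation name rather than an integer tensor-core version. A usage sketch, assuming a pm pass manager as built in optimize_ttgir below; the arch string "gfx1100" and the instruction size 16 are illustrative values, not from the patch:

    # Illustrative call through the new binding; the arch string lets the
    # pass choose between MFMA (CDNA) and WMMA (RDNA3) lowering.
    pm.add_tritonamdgpu_accelerate_matmul_pass("gfx1100", 16)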
@@ -117,9 +117,9 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
     pm.add_tritongpu_accelerate_matmul_pass(capability)
     # TODO change interface of accelerate_matmul_pass
     if is_hip():
-        matrix_core_version = target["matrix_core_version"]
+        gfx_arch = target["gfx_arch"]
         matrix_inst_size = matrix_inst_type
-        pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size)
+        pm.add_tritonamdgpu_accelerate_matmul_pass(gfx_arch, matrix_inst_size)
         pm.add_tritongpu_remove_layout_conversions_pass()
     if optimize_epilogue:
         pm.add_tritongpu_optimize_epilogue_pass()
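For reference, a minimal sketch of the target descriptor consumed above, assuming the HIP backend now publishes the architecture name under the "gfx_arch" key in place of the old "matrix_core_version" integer; every other key and value here is illustrative:

    # Hypothetical target descriptor for an RDNA3 GPU; only the "gfx_arch"
    # key is taken from the diff above.
    target = {"gfx_arch": "gfx1100", "warp_size": 32}
    gfx_arch = target["gfx_arch"]  # -> "gfx1100"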