Dot slicing pass (#440)

* First commit

* Implement DotSlicing pass.

* small fixes

* Support chained dot in DotSlicingPass (second GEMM in FA)

* Add lit test for FA dot slicing

---------

Co-authored-by: Ognjen Plavsic <ognjen.plavsic@luxoft.com>
Co-authored-by: Ognjen <oplavsic@luxoft.com>
This commit is contained in:
oplavsic
2024-01-16 21:25:10 +01:00
committed by GitHub
parent a819e48435
commit 760ac8441a
10 changed files with 725 additions and 11 deletions

View File

@@ -12,6 +12,7 @@ std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 3,
int computeCapability = 80);
// Factory for the TritonGPU stream-pipeline pass; takes no configuration.
std::unique_ptr<Pass> createTritonGPUStreamPipelinePass();
// Factory for the AMD GPU dot-slicing pass ("tritonamdgpu-dot-slicing").
// `sliceKTile` is the slice size in the k dimension and defaults to 0.
// NOTE(review): the meaning of 0 is not visible from this declaration
// (presumably "do not slice") — confirm against the pass implementation.
std::unique_ptr<Pass> createTritonAMDGPUDotSlicingPass(int sliceKTile = 0);
// Factory for the accelerate-matmul pass; `computeCapability` defaults to 80.
std::unique_ptr<Pass>
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);

View File

@@ -49,6 +49,26 @@ def TritonGPUStreamPipeline : Pass<"tritongpu-stream-pipeline", "mlir::ModuleOp"
"mlir::arith::ArithDialect"];
}
// Declarative definition of the "tritonamdgpu-dot-slicing" pass, which is
// anchored on mlir::ModuleOp and exposes one option, `slice-k-tile`.
def TritonAMDGPUDotSlicing: Pass<"tritonamdgpu-dot-slicing", "mlir::ModuleOp"> {
let summary = "'DotOp' instruction slicing";
let description = [{
Slice 'DotOp' instruction into multiple smaller 'DotOp' instructions
in order to improve scheduling and latency hiding.
}];
// The factory is invoked with no arguments, so its own default for
// `sliceKTile` applies when the pass is built through this constructor.
let constructor = "mlir::createTritonAMDGPUDotSlicingPass()";
// Dialects that must be loaded before this pass runs.
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::arith::ArithDialect"];
// `sliceKTile` is the C++ member name; `slice-k-tile` is the option name
// used on the command line / in pass pipelines.
// NOTE(review): the semantics of the default value 0 are not shown here
// (presumably slicing is disabled) — confirm in the pass implementation.
let options = [
Option<"sliceKTile", "slice-k-tile",
"int32_t", /*default*/"0",
"slice size in k dimension">
];
}
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";