mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Dot slicing pass (#440)
* First commit * Implement DotSlicing pass. * small fixes * Support chained dot in DotSlicingPass (second GEMM in FA) * Add lit test for FA dot slicing --------- Co-authored-by: Ognjen Plavsic <ognjen.plavsic@luxoft.com> Co-authored-by: Ognjen <oplavsic@luxoft.com>
This commit is contained in:
@@ -12,6 +12,7 @@ std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 3,
|
||||
int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUStreamPipelinePass();
|
||||
std::unique_ptr<Pass> createTritonAMDGPUDotSlicingPass(int sliceKTile = 0);
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);
|
||||
|
||||
@@ -49,6 +49,26 @@ def TritonGPUStreamPipeline : Pass<"tritongpu-stream-pipeline", "mlir::ModuleOp"
|
||||
"mlir::arith::ArithDialect"];
|
||||
}
|
||||
|
||||
def TritonAMDGPUDotSlicing: Pass<"tritonamdgpu-dot-slicing", "mlir::ModuleOp"> {
|
||||
let summary = "'DotOp' instruction slicing";
|
||||
|
||||
let description = [{
|
||||
Slice 'DotOp' instruction into multiple smaller 'DotOp' instructions
|
||||
in order to improve scheduling and latency hiding.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonAMDGPUDotSlicingPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"sliceKTile", "slice-k-tile",
|
||||
"int32_t", /*default*/"0",
|
||||
"slice size in k dimension">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
|
||||
let summary = "prefetch";
|
||||
|
||||
|
||||
Reference in New Issue
Block a user