mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-27 03:01:52 -04:00
[BACKEND] Optimize wgmma with accumulator source equal to 0 (#2343)
Also add a test for MMA v3 reduction.
This commit is contained in:
@@ -149,3 +149,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: @dot_zero_acc
|
||||
// Generate a wgmma with 2 sources.
|
||||
// CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} {
|
||||
tt.func @dot_zero_acc(%a: tensor<128x64xf16, #shared>, %b: tensor<64x64xf16, #shared1>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
|
||||
%m = triton_nvidia_gpu.dot_async %a, %b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} :
|
||||
tensor<128x64xf16, #shared> * tensor<64x64xf16, #shared1> -> tensor<128x64xf32, #mma>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user