[BACKEND] Optimize wgmma with accumulator source equal to 0 (#2343)

Also add a test for MMA v3 reduction.

When the accumulator source of the wgmma is the constant 0, the backend can emit the wgmma with only its two multiplicand operands instead of materializing and adding a zero accumulator.
Author: Thomas Raoux
Committed by: GitHub
Date: 2023-09-20 14:05:12 -07:00
Parent: ed5a53057d
Commit: 9cab885dff
3 changed files with 35 additions and 3 deletions


@@ -149,3 +149,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
tt.return
}
}
// -----
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: @dot_zero_acc
// Since the accumulator is the zero constant, generate a wgmma with only 2 sources.
// CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} {
tt.func @dot_zero_acc(%a: tensor<128x64xf16, #shared>, %b: tensor<64x64xf16, #shared1>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%m = triton_nvidia_gpu.dot_async %a, %b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} :
tensor<128x64xf16, #shared> * tensor<64x64xf16, #shared1> -> tensor<128x64xf32, #mma>
tt.return
}
}
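
For context, here is a minimal sketch of Triton source that produces the pattern exercised by the test. The kernel name, strides, and block sizes are hypothetical and not part of the commit; it assumes fp16 inputs and a Hopper GPU so that the dot lowers to triton_nvidia_gpu.dot_async. A tl.dot with no explicit accumulator receives an all-zero constant accumulator (the %cst = dense<0.000000e+00> above), which is the case this commit optimizes.

import triton
import triton.language as tl

@triton.jit
def single_tile_dot(a_ptr, b_ptr, c_ptr,
                    stride_am, stride_ak,
                    stride_bk, stride_bn,
                    stride_cm, stride_cn,
                    BLOCK_M: tl.constexpr,   # e.g. 128, matching the test's shapes
                    BLOCK_N: tl.constexpr,   # e.g. 64
                    BLOCK_K: tl.constexpr):  # e.g. 64
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Load one BLOCK_M x BLOCK_K tile of A and one BLOCK_K x BLOCK_N tile of B.
    a = tl.load(a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b = tl.load(b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    # No explicit accumulator: the frontend inserts a zero constant as the
    # dot's accumulator operand, so this lowers to a dot whose accumulator
    # is dense<0.0>; after this commit the resulting wgmma is emitted with
    # only the two multiplicand sources.
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, c)

With BLOCK_M = 128, BLOCK_N = 64, BLOCK_K = 64, this matches the 128x64 * 64x64 -> 128x64 shapes in the test above.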