[BACKEND] Fix multiple bugs in WGMMA (#2457)

Fix dependencies in the wgmma_wait op to prevent the scheduler from moving
it past the uses of the wgmma accumulator. We need to explicitly represent
the dependency between the wait and the accumulator uses; otherwise LLVM
is free to re-order those.
This allows us to remove a workaround to prevent the re-ordering. We can
also remove the wait op added in the loop during pipelining.

Also fix the descriptor calculation for wgmma; we should calculate the
same descriptor for the whole warpgroup.
Added a workaround for a bug that was exposed by different timing due to
those changes. We shouldn't insert operations between the loop and
async_wait or we may have race conditions.
This commit is contained in:
Thomas Raoux
2023-10-06 17:59:28 -07:00
committed by GitHub
parent ded79e87ee
commit a7061e19b2
12 changed files with 137 additions and 85 deletions

View File

@@ -175,6 +175,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-LABEL: @dot_reg_operand_A
// Generate a wgmma where the first operand is a struct.
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: tensor<64x64xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%opA = triton_gpu.convert_layout %a : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>

View File

@@ -30,3 +30,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2
tt.return
}
}
// -----
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} {
tt.func @wgmma_wait(%in: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) {
// CHECK: // wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63
// CHECK: wgmma.wait_group.sync.aligned 0;
%out = nvgpu.wgmma_wait_group %in {pendings = 0 : i32} :
!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
tt.return
}
}

View File

@@ -52,7 +52,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
triton_nvidia_gpu.mbarrier_wait %33, %arg23 : <i64, 3>
// CHECK: triton_nvidia_gpu.fence_async_shared
%34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma>
triton_nvidia_gpu.dot_wait {pendings = 1 : i32}
%35 = tt.advance %arg11, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
%36 = tt.advance %arg12, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
%37 = arith.addi %arg19, %c128_i32 : i32
@@ -88,10 +87,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%64 = arith.ori %62, %63 : i1
scf.yield %34, %35, %36, %47, %49, %s_48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, i32, i32, i32, i32, i1, i1
}
scf.if %10 {
triton_nvidia_gpu.dot_wait {pendings = 0 : i32}
}
%31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%w = triton_nvidia_gpu.dot_wait %30#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma>
%31 = arith.truncf %w : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1>
triton_nvidia_gpu.store_async %8, %32 : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<128x128xf16, #shared1>
triton_gpu.async_bulk_commit_group
@@ -158,7 +155,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
triton_nvidia_gpu.mbarrier_wait %33, %arg23 : <i64, 3>
// CHECK: triton_nvidia_gpu.fence_async_shared
%34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma>
triton_nvidia_gpu.dot_wait {pendings = 1 : i32}
%35 = tt.advance %arg11, [%c0_i32, %c128_i32] : <tensor<128x128xf16, #blocked>, 1>
%36 = tt.advance %arg12, [%c128_i32, %c0_i32] : <tensor<128x128xf16, #blocked>, 1>
%37 = arith.addi %arg19, %c128_i32 : i32
@@ -192,10 +188,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%64 = arith.ori %62, %63 : i1
scf.yield %34, %35, %36, %47, %49, %48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, !tt.ptr<tensor<128x128xf16, #blocked>, 1>, i32, i32, i32, i32, i1, i1
}
scf.if %10 {
triton_nvidia_gpu.dot_wait {pendings = 0 : i32}
}
%31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%w = triton_nvidia_gpu.dot_wait %30#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma>
%31 = arith.truncf %w : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1>
triton_nvidia_gpu.store_async %8, %32 : !tt.ptr<tensor<128x128xf16, #blocked>, 1>, tensor<128x128xf16, #shared1>
triton_gpu.async_bulk_commit_group

View File

@@ -22,9 +22,9 @@
// CHECK: triton_gpu.extract_slice
// CHECK: triton_gpu.extract_slice
// CHECK: triton_nvidia_gpu.dot_async
// CHECK: triton_nvidia_gpu.dot_wait
// CHECK: triton_nvidia_gpu.consumer_release
// CHECK: scf.yield
// CHECK: triton_nvidia_gpu.dot_wait
// CHECK: async_agent = dense<1> : vector<1xi32>
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>