[BACKEND] Allow reduce with sliced 3D layout as input (#2480)

This commit is contained in:
Zahi Moudallal
2023-10-10 15:19:11 -07:00
committed by GitHub
parent 5812d970a8
commit 4749072fbd
2 changed files with 13 additions and 12 deletions

View File

@@ -161,8 +161,10 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
}
bool ReduceOpHelper::isWarpSynchronous() {
auto argsLayout = getSrcLayout();
return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1;
auto srcLayout = getSrcLayout();
auto srcShape = getSrcShape();
return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] ==
1;
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {

View File

@@ -1469,20 +1469,19 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
// -----
// CHECK-LABEL: copyitem
// CHECK: st.shared.b8
// CHECK: ld.shared.b8
// CHECK-NOT: st.shared.b1
// CHECK-NOT: ld.shared.b1
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @copyitem() attributes {noinline = false} {
%cst = arith.constant dense<true> : tensor<4x1xi1, #blocked>
// CHECK-LABEL: reduce_slice
// CHECK-NOT: st.shared
// CHECK-NOT: ld.shared
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}>
#sliced2 = #triton_gpu.slice<{dim = 2, parent = #blocked}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @reduce_slice() attributes {noinline = false} {
%cst = arith.constant dense<true> : tensor<4x1xi1, #sliced2>
%0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({
^bb0(%arg0: i1, %arg1: i1):
%1 = arith.ori %arg0, %arg1 : i1
tt.reduce.return %1 : i1
}) : (tensor<4x1xi1, #blocked>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
}) : (tensor<4x1xi1, #sliced2>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #sliced2}>>
tt.return
}
}