mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Allow reduce with sliced 3D layout as input (#2480)
This commit is contained in:
@@ -161,8 +161,10 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
|
||||
}
|
||||
|
||||
bool ReduceOpHelper::isWarpSynchronous() {
|
||||
auto argsLayout = getSrcLayout();
|
||||
return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1;
|
||||
auto srcLayout = getSrcLayout();
|
||||
auto srcShape = getSrcShape();
|
||||
return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] ==
|
||||
1;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {
|
||||
|
||||
@@ -1469,20 +1469,19 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: copyitem
|
||||
// CHECK: st.shared.b8
|
||||
// CHECK: ld.shared.b8
|
||||
// CHECK-NOT: st.shared.b1
|
||||
// CHECK-NOT: ld.shared.b1
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
tt.func public @copyitem() attributes {noinline = false} {
|
||||
%cst = arith.constant dense<true> : tensor<4x1xi1, #blocked>
|
||||
// CHECK-LABEL: reduce_slice
|
||||
// CHECK-NOT: st.shared
|
||||
// CHECK-NOT: ld.shared
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}>
|
||||
#sliced2 = #triton_gpu.slice<{dim = 2, parent = #blocked}>
|
||||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
tt.func public @reduce_slice() attributes {noinline = false} {
|
||||
%cst = arith.constant dense<true> : tensor<4x1xi1, #sliced2>
|
||||
%0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({
|
||||
^bb0(%arg0: i1, %arg1: i1):
|
||||
%1 = arith.ori %arg0, %arg1 : i1
|
||||
tt.reduce.return %1 : i1
|
||||
}) : (tensor<4x1xi1, #blocked>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
}) : (tensor<4x1xi1, #sliced2>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #sliced2}>>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user