mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] ReduceOp to support arbitrary reduce operations (#1305)
Fixes #1285

This changes `tt.reduce` to replace `redOp` by a region containing arbitrary code. For example, `tl.sum` is now lowered as:

```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```

Support for index reductions at the MLIR level is also dropped in favor of simultaneous reductions over multiple tensors, which generalizes the code without loss of performance. For example, `argmin` gets lowered as:

```mlir
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
%9:2 = "tt.reduce"(%6, %8) ({
^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
  %14 = arith.cmpf olt, %arg4, %arg6 : f32
  %15 = arith.cmpf ogt, %arg4, %arg6 : f32
  %16 = arith.cmpi slt, %arg5, %arg7 : i32
  %17 = arith.select %16, %arg5, %arg7 : i32
  %18 = arith.select %15, %arg7, %17 : i32
  %19 = arith.select %14, %arg5, %18 : i32
  %20 = arith.cmpf olt, %arg4, %arg6 : f32
  %21 = arith.select %20, %arg4, %arg6 : f32
  tt.reduce.return %21, %19 : f32, i32
}) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
This commit is contained in:
@@ -217,7 +217,11 @@ tt.func @alloc(%A : !tt.ptr<f16>) {
|
||||
tt.func @scratch() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
// CHECK: scratch offset = 0, size = 512
|
||||
%b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
|
||||
%b = "tt.reduce" (%cst0) ({
|
||||
^bb0(%arg0: f16, %arg1: f16):
|
||||
%add = arith.addf %arg0, %arg1 : f16
|
||||
tt.reduce.return %add : f16
|
||||
}) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0>
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 512
|
||||
}
|
||||
|
||||
@@ -79,7 +79,11 @@ tt.func @scratch() {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
|
||||
%2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
|
||||
%2 = "tt.reduce" (%1) ({
|
||||
^bb0(%arg1: f16, %arg2: f16):
|
||||
%add = arith.addf %arg1, %arg2 : f16
|
||||
tt.reduce.return %add : f16
|
||||
}) {axis = 0 : i32} : (tensor<32x16xf16, #AL>) -> tensor<16xf16, #sliceAd0>
|
||||
tt.return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user