[FRONTEND][BACKEND] ReduceOp to support arbitrary reduce operations (#1305)

Fixes #1285

This changes `tt.reduce` to replace the `redOp` attribute with a region
containing arbitrary code. For example, `tl.sum` is now lowered as:
```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```
Support for index reductions at the MLIR level is also dropped in favor
of simultaneous reductions over multiple tensors, which generalizes the
code without loss of performance. So, for example, `argmin` gets lowered
as:
```mlir
  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
  %8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
  %9:2 = "tt.reduce"(%6, %8) ({
  ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
    %14 = arith.cmpf olt, %arg4, %arg6 : f32
    %15 = arith.cmpf ogt, %arg4, %arg6 : f32
    %16 = arith.cmpi slt, %arg5, %arg7 : i32
    %17 = arith.select %16, %arg5, %arg7 : i32
    %18 = arith.select %15, %arg7, %17 : i32
    %19 = arith.select %14, %arg5, %18 : i32
    %20 = arith.cmpf olt, %arg4, %arg6 : f32
    %21 = arith.select %20, %arg4, %arg6 : f32
    tt.reduce.return %21, %19 : f32, i32
  }) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
This commit is contained in:
peterbell10
2023-04-13 01:37:39 +00:00
committed by GitHub
parent 5b9119117b
commit e152183570
23 changed files with 822 additions and 606 deletions

View File

@@ -217,7 +217,11 @@ tt.func @alloc(%A : !tt.ptr<f16>) {
tt.func @scratch() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
// CHECK: scratch offset = 0, size = 512
%b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
%b = "tt.reduce" (%cst0) ({
^bb0(%arg0: f16, %arg1: f16):
%add = arith.addf %arg0, %arg1 : f16
tt.reduce.return %add : f16
}) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0>
tt.return
// CHECK-NEXT: size = 512
}

View File

@@ -79,7 +79,11 @@ tt.func @scratch() {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
%2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
%2 = "tt.reduce" (%1) ({
^bb0(%arg1: f16, %arg2: f16):
%add = arith.addf %arg1, %arg2 : f16
tt.reduce.return %add : f16
}) {axis = 0 : i32} : (tensor<32x16xf16, #AL>) -> tensor<16xf16, #sliceAd0>
tt.return
}