[BACKEND] Remove ttg.cmp and ttg.select and replace by arith op (#2526)

Now that the bug related to attribute is fixed in MLIR we can use arith
ops for cmp and select ops.
This commit is contained in:
Thomas Raoux
2023-10-23 19:35:46 -07:00
committed by GitHub
parent b0c166b9e3
commit cba7abd682
20 changed files with 247 additions and 369 deletions

View File

@@ -79,8 +79,7 @@ tt.func public @select_op(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
// CHECK: %[[splat:.*]] = tt.splat %arg2 : (i1) -> tensor<128xi1, #blocked>
// CHECK-NEXT: %{{.*}} = "triton_gpu.select"(%[[splat]], %{{.*}}, %{{.*}}) : (tensor<128xi1, #blocked>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>) -> tensor<128xf32, #blocked>
// CHECK: %{{.*}} = arith.select %arg2, %{{.*}}, %{{.*}} : tensor<128xf32, #blocked>
%4 = arith.select %arg2, %cst, %3 : tensor<128xf32>
%5 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>

View File

@@ -200,7 +200,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked>
%10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked>
%10 = arith.cmpi "slt", %4, %9 : tensor<64xi32, #blocked>
// load op has a vector width = 1 due to the %mask's alignment
// CHECK: ld.global.b32
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>