mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Remove ttg.cmp and ttg.select and replace by arith op (#2526)
Now that the bug related to attribute is fixed in MLIR we can use arith ops for cmp and select ops.
This commit is contained in:
@@ -79,8 +79,7 @@ tt.func public @select_op(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
|
||||
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
|
||||
|
||||
// CHECK: %[[splat:.*]] = tt.splat %arg2 : (i1) -> tensor<128xi1, #blocked>
|
||||
// CHECK-NEXT: %{{.*}} = "triton_gpu.select"(%[[splat]], %{{.*}}, %{{.*}}) : (tensor<128xi1, #blocked>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>) -> tensor<128xf32, #blocked>
|
||||
// CHECK: %{{.*}} = arith.select %arg2, %{{.*}}, %{{.*}} : tensor<128xf32, #blocked>
|
||||
%4 = arith.select %arg2, %cst, %3 : tensor<128xf32>
|
||||
|
||||
%5 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
|
||||
@@ -200,7 +200,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked>
|
||||
%10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked>
|
||||
%10 = arith.cmpi "slt", %4, %9 : tensor<64xi32, #blocked>
|
||||
// load op has a vector width = 1 due to the %mask's alignment
|
||||
// CHECK: ld.global.b32
|
||||
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
|
||||
|
||||
Reference in New Issue
Block a user