[BACKEND] Disallow the CombineSelectMaskedLoad pattern when the conditions of the select and the broadcast differ (#1170)

This commit is contained in:
Keren Zhou
2023-02-09 18:03:22 -05:00
committed by GitHub
parent 0cbe368fe5
commit c61c8a123f
2 changed files with 16 additions and 5 deletions

View File

@@ -69,6 +69,7 @@ public:
mlir::Value trueValue = selectOp.getTrueValue();
mlir::Value falseValue = selectOp.getFalseValue();
mlir::Value condSelect = selectOp.getCondition();
auto *loadOpCandidate = trueValue.getDefiningOp();
auto loadOp = llvm::dyn_cast_or_null<triton::LoadOp>(loadOpCandidate);
@@ -85,6 +86,10 @@ public:
if (!broadcastOp)
return mlir::failure();
auto broadcastCond = broadcastOp.src();
if (broadcastCond != condSelect)
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, loadOp.ptr(), loadOp.mask(), falseValue, loadOp.cache(),
loadOp.evict(), loadOp.isVolatile());

View File

@@ -64,19 +64,25 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %con
}
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
%false_val = arith.constant dense<0.0> : tensor<8xf32>
// Case 1: the value at the "load" position is not defined by an op (it is a function argument). Select should not be canonicalized.
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
%0 = select %cond, %dummy_load, %false_val : tensor<8xf32>
%0 = select %cond0, %dummy_load, %false_val : tensor<8xf32>
// Case 2: the value at the "broadcast" position (the load mask) is not defined by an op (it is a function argument). Select should not be canonicalized.
%real_load = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
%real_load0 = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
%1 = select %cond, %real_load, %false_val : tensor<8xf32>
%1 = select %cond0, %real_load0, %false_val : tensor<8xf32>
return %0, %1 : tensor<8xf32>, tensor<8xf32>
// Case 3: condition of "broadcast" is not the same as the condition of "select". Select should not be canonicalized.
%cond0_ = tt.broadcast %cond0 : (i1) -> tensor<8xi1>
%real_load1 = tt.load %ptr, %cond0_, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
%2 = select %cond1, %real_load1, %false_val : tensor<8xf32>
return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_combine_broadcast_constant_pattern