mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Disallow the CombineSelectMaskedLoad pattern if conditions of select and broadcast are different (#1170)
This commit is contained in:
@@ -69,6 +69,7 @@ public:
|
||||
|
||||
mlir::Value trueValue = selectOp.getTrueValue();
|
||||
mlir::Value falseValue = selectOp.getFalseValue();
|
||||
mlir::Value condSelect = selectOp.getCondition();
|
||||
|
||||
auto *loadOpCandidate = trueValue.getDefiningOp();
|
||||
auto loadOp = llvm::dyn_cast_or_null<triton::LoadOp>(loadOpCandidate);
|
||||
@@ -85,6 +86,10 @@ public:
|
||||
if (!broadcastOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto broadcastCond = broadcastOp.src();
|
||||
if (broadcastCond != condSelect)
|
||||
return mlir::failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, loadOp.ptr(), loadOp.mask(), falseValue, loadOp.cache(),
|
||||
loadOp.evict(), loadOp.isVolatile());
|
||||
|
||||
@@ -64,19 +64,25 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %con
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
|
||||
func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
// Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
|
||||
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%0 = select %cond, %dummy_load, %false_val : tensor<8xf32>
|
||||
%0 = select %cond0, %dummy_load, %false_val : tensor<8xf32>
|
||||
|
||||
// Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
|
||||
%real_load = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%real_load0 = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%1 = select %cond, %real_load, %false_val : tensor<8xf32>
|
||||
%1 = select %cond0, %real_load0, %false_val : tensor<8xf32>
|
||||
|
||||
return %0, %1 : tensor<8xf32>, tensor<8xf32>
|
||||
// Case 3: condition of "broadcast" is not the same as the condition of "select". Select should not be canonicalized.
|
||||
%cond0_ = tt.broadcast %cond0 : (i1) -> tensor<8xi1>
|
||||
%real_load1 = tt.load %ptr, %cond0_, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%2 = select %cond1, %real_load1, %false_val : tensor<8xf32>
|
||||
|
||||
return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
|
||||
|
||||
Reference in New Issue
Block a user