mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Relax patterns to move sink broadcast and hoist convert (#2331)
Improve patterns that sink broadcast to reduce the arithmetic density and also hoist convert on top of expand_dims to do less work. This addresses comments in https://github.com/openai/triton/pull/2274
This commit is contained in:
@@ -116,18 +116,20 @@ struct MoveBroadcastAfterElementwisePattern
|
||||
|
||||
auto operands = op->getOperands();
|
||||
bool seenBroadcast = false;
|
||||
Type srcType;
|
||||
ArrayRef<int64_t> srcShape;
|
||||
for (auto operand : operands) {
|
||||
auto definingOp = operand.getDefiningOp();
|
||||
if (!definingOp) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto getSrcShape = [](triton::BroadcastOp b) {
|
||||
return b.getSrc().getType().cast<RankedTensorType>().getShape();
|
||||
};
|
||||
if (auto broadcastOp = llvm::dyn_cast<triton::BroadcastOp>(definingOp)) {
|
||||
if (!seenBroadcast) {
|
||||
seenBroadcast = true;
|
||||
srcType = broadcastOp.getSrc().getType();
|
||||
} else if (srcType != broadcastOp.getSrc().getType()) {
|
||||
srcShape = getSrcShape(broadcastOp);
|
||||
} else if (srcShape != getSrcShape(broadcastOp)) {
|
||||
// If the broadcasts have different types we cannot re-order.
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
@@ -929,7 +929,7 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
|
||||
|
||||
auto isExtOrBroadcastOp = [](Operation *op) {
|
||||
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp,
|
||||
triton::BroadcastOp>(op);
|
||||
triton::BroadcastOp, triton::ExpandDimsOp>(op);
|
||||
};
|
||||
// 1. Take a backward slice of all the tensor dependencies.
|
||||
SetVector<Value> slice;
|
||||
@@ -950,8 +950,11 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
|
||||
if (isExtOrBroadcastOp(op)) {
|
||||
SetVector<Value> tempSlice;
|
||||
DenseMap<Value, Attribute> tempLayout;
|
||||
std::optional<Attribute> srcEncoding = inferSrcEncoding(op, layout[v]);
|
||||
if (!srcEncoding)
|
||||
return;
|
||||
LogicalResult result = getRematerializableSlice(
|
||||
op->getOperand(0), layout[v], tempSlice, tempLayout);
|
||||
op->getOperand(0), *srcEncoding, tempSlice, tempLayout);
|
||||
// If we can rematerialize the rest of the ext slice we can ignore this
|
||||
// ext as it won't need a convert.
|
||||
if (result.succeeded()) {
|
||||
@@ -969,13 +972,16 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
|
||||
|
||||
if (extOrBroadcatOp == nullptr)
|
||||
return;
|
||||
std::optional<Attribute> srcEncoding =
|
||||
inferSrcEncoding(extOrBroadcatOp, layout[extOrBroadcatOp->getResult(0)]);
|
||||
if (!srcEncoding)
|
||||
return;
|
||||
// Move the convert before the ext op and rewrite the slice.
|
||||
OpBuilder builder(extOrBroadcatOp);
|
||||
auto tensorType =
|
||||
extOrBroadcatOp->getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto newType =
|
||||
RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(),
|
||||
layout[extOrBroadcatOp->getResult(0)]);
|
||||
auto newType = RankedTensorType::get(
|
||||
tensorType.getShape(), tensorType.getElementType(), *srcEncoding);
|
||||
auto newConvertOp = builder.create<ConvertLayoutOp>(
|
||||
convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0));
|
||||
IRMapping mapping;
|
||||
|
||||
@@ -53,3 +53,15 @@ tt.func @test_broadcast_binary_op_pattern(%arg0: tensor<128x1xf32>, %arg1: tenso
|
||||
|
||||
tt.return %mul, %mul1 : tensor<128x128xf32>, tensor<128x128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_broadcast_mix_type_op_pattern
|
||||
tt.func @test_broadcast_mix_type_op_pattern(%arg0: tensor<128x1xf32>, %arg1: f32, %arg2: tensor<1x128xf32>, %arg3: tensor<128x1xi1>) -> (tensor<128x128xf32>) {
|
||||
// CHECK: %[[sel:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<128x1xi1>, tensor<128x1xf32>
|
||||
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[sel]] : (tensor<128x1xf32>) -> tensor<128x128xf32>
|
||||
%broadcast0 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32>
|
||||
%broadcast1 = tt.splat %arg1 : (f32) -> tensor<128x128xf32>
|
||||
%cond = tt.broadcast %arg3 : (tensor<128x1xi1>) -> tensor<128x128xi1>
|
||||
%sel = arith.select %cond, %broadcast0, %broadcast1 : tensor<128x128xi1>, tensor<128x128xf32>
|
||||
|
||||
tt.return %sel : tensor<128x128xf32>
|
||||
}
|
||||
|
||||
@@ -1189,10 +1189,12 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
|
||||
// CHECK-LABEL: reduce_cvt2
|
||||
// Match the reduction
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: tt.reduce
|
||||
// CHECK-SAME: axis = 1
|
||||
// CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>>
|
||||
// CHECK: triton_gpu.convert_layout
|
||||
// CHECK: tt.expand_dims
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: tt.return
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
|
||||
Reference in New Issue
Block a user