[BACKEND] Relax patterns to move sink broadcast and hoist convert (#2331)

Improve patterns that sink broadcast to reduce the arithmetic density,
and also hoist convert on top of expand_dims to do less work.

This addresses comments in https://github.com/openai/triton/pull/2274
This commit is contained in:
Thomas Raoux
2023-09-18 15:08:19 -07:00
committed by GitHub
parent 73dae775df
commit 3a848e2729
4 changed files with 32 additions and 10 deletions

View File

@@ -116,18 +116,20 @@ struct MoveBroadcastAfterElementwisePattern
auto operands = op->getOperands();
bool seenBroadcast = false;
Type srcType;
ArrayRef<int64_t> srcShape;
for (auto operand : operands) {
auto definingOp = operand.getDefiningOp();
if (!definingOp) {
return mlir::failure();
}
auto getSrcShape = [](triton::BroadcastOp b) {
return b.getSrc().getType().cast<RankedTensorType>().getShape();
};
if (auto broadcastOp = llvm::dyn_cast<triton::BroadcastOp>(definingOp)) {
if (!seenBroadcast) {
seenBroadcast = true;
srcType = broadcastOp.getSrc().getType();
} else if (srcType != broadcastOp.getSrc().getType()) {
srcShape = getSrcShape(broadcastOp);
} else if (srcShape != getSrcShape(broadcastOp)) {
// If the broadcast have different types we cannot re-order.
return mlir::failure();
}

View File

@@ -929,7 +929,7 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
auto isExtOrBroadcastOp = [](Operation *op) {
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp,
triton::BroadcastOp>(op);
triton::BroadcastOp, triton::ExpandDimsOp>(op);
};
// 1. Take a backward slice of all the tensor dependencies.
SetVector<Value> slice;
@@ -950,8 +950,11 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
if (isExtOrBroadcastOp(op)) {
SetVector<Value> tempSlice;
DenseMap<Value, Attribute> tempLayout;
std::optional<Attribute> srcEncoding = inferSrcEncoding(op, layout[v]);
if (!srcEncoding)
return;
LogicalResult result = getRematerializableSlice(
op->getOperand(0), layout[v], tempSlice, tempLayout);
op->getOperand(0), *srcEncoding, tempSlice, tempLayout);
// If we can rematerialize the rest of the ext slice we can ignore this
// ext as it won't need a convert.
if (result.succeeded()) {
@@ -969,13 +972,16 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
if (extOrBroadcatOp == nullptr)
return;
std::optional<Attribute> srcEncoding =
inferSrcEncoding(extOrBroadcatOp, layout[extOrBroadcatOp->getResult(0)]);
if (!srcEncoding)
return;
// Move the convert before the ext op and rewrite the slice.
OpBuilder builder(extOrBroadcatOp);
auto tensorType =
extOrBroadcatOp->getOperand(0).getType().cast<RankedTensorType>();
auto newType =
RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(),
layout[extOrBroadcatOp->getResult(0)]);
auto newType = RankedTensorType::get(
tensorType.getShape(), tensorType.getElementType(), *srcEncoding);
auto newConvertOp = builder.create<ConvertLayoutOp>(
convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0));
IRMapping mapping;

View File

@@ -53,3 +53,15 @@ tt.func @test_broadcast_binary_op_pattern(%arg0: tensor<128x1xf32>, %arg1: tenso
tt.return %mul, %mul1 : tensor<128x128xf32>, tensor<128x128xf32>
}
// CHECK-LABEL: @test_broadcast_mix_type_op_pattern
tt.func @test_broadcast_mix_type_op_pattern(%arg0: tensor<128x1xf32>, %arg1: f32, %arg2: tensor<1x128xf32>, %arg3: tensor<128x1xi1>) -> (tensor<128x128xf32>) {
// CHECK: %[[sel:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<128x1xi1>, tensor<128x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[sel]] : (tensor<128x1xf32>) -> tensor<128x128xf32>
%broadcast0 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32>
%broadcast1 = tt.splat %arg1 : (f32) -> tensor<128x128xf32>
%cond = tt.broadcast %arg3 : (tensor<128x1xi1>) -> tensor<128x128xi1>
%sel = arith.select %cond, %broadcast0, %broadcast1 : tensor<128x128xi1>, tensor<128x128xf32>
tt.return %sel : tensor<128x128xf32>
}

View File

@@ -1189,10 +1189,12 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: reduce_cvt2
// Match the reduction
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: tt.reduce
// CHECK-SAME: axis = 1
// CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>>
// CHECK: triton_gpu.convert_layout
// CHECK: tt.expand_dims
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: tt.return
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>