[OPTIMIZER] Add pass to move broadcasts after elementwise operations (#1811)

This adds a pass that tries to reduce the shape of tensor arguments to
element-wise operations by moving splat and broadcast operations later
in the graph. So, for example say we have:

```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset  + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp2 = 0.017453292519943295
    tmp3 = tmp1 * tmp2
    tmp4 = tl.sin(tmp3)
    tl.store(out_ptr0 + (x0), tmp4, None)
```

Today this results in duplicate `sin` calls:
```
    %27 = llvm.fmul %26, %3  : f32
    %28 = llvm.call @__nv_sinf(%27) : (f32) -> f32
    %29 = llvm.call @__nv_sinf(%27) : (f32) -> f32
```

The duplicate `llvm.fmul` calls are eliminated via CSE, but `llvm.call`
doesn't get CSE'd because it might be impure.

After this change, the sin is done on a scalar value in the triton IR
and splatted at the very end, so no duplicate calculation happens within
a thread.

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
peterbell10
2023-07-10 19:44:38 +01:00
committed by GitHub
parent ef947dac31
commit e3d9478d31
10 changed files with 342 additions and 6 deletions

View File

@@ -31,7 +31,8 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
// fptoui, fptosi, uitofp, sitofp,
// extf, truncf,
// extui, extsi, trunci
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
@@ -44,7 +45,8 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
@@ -58,7 +60,8 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
}
// arith.bitcast doesn't support pointers
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
@@ -73,6 +76,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
// TODO: Add verifier
}
// FIXME: Not elementwise because scalars are not supported
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
@@ -99,6 +103,7 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
//
def TT_AddPtrOp : TT_Op<"addptr",
[Pure,
Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
@@ -458,7 +463,8 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
//
class TT_ExternElementwiseOpBase<string mnemonic, list<Trait> traits = []> :
TT_Op<mnemonic,
traits # [SameOperandsAndResultEncoding,
traits # [Elementwise,
SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
let description = [{

View File

@@ -8,6 +8,8 @@ namespace triton {
std::unique_ptr<Pass> createCombineOpsPass();
std::unique_ptr<Pass> createReorderBroadcastPass();
std::unique_ptr<Pass>
createRewriteTensorPointerPass(int computeCapability = 80);

View File

@@ -19,6 +19,15 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
let dependentDialects = ["mlir::arith::ArithDialect"];
}
def TritonReorderBroadcast : Pass</*cli-arg*/"triton-reorder-broadcast", /*Op*/"mlir::ModuleOp"> {
let summary = "Moves broadcast and splat after elementwise operations";
let description = [{
elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))
}];
let constructor = "mlir::triton::createReorderBroadcastPass()";
let dependentDialects = ["mlir::triton::TritonDialect"];
}
def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer", /*Op*/"mlir::ModuleOp"> {
let summary = "Rewrite load/stores with tensor pointers into legacy load/stores";
let description = [{

View File

@@ -641,6 +641,31 @@ LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op,
splat.getOperand());
return mlir::success();
}
// expand_dims(broadcast) -> broadcast(expand_dims)
//
// On its own this doesn't do much, but consider
// broadcast(expand_dims(broadcast))
// -> broadcast(broadcast(expand_dims))
// -> broadcast(expand_dims)
if (auto broadcast = dyn_cast<triton::BroadcastOp>(definingOp)) {
auto src = broadcast.getSrc();
auto srcTy = src.getType().dyn_cast<RankedTensorType>();
auto elemTy = srcTy.getElementType();
auto srcShape = srcTy.getShape();
llvm::SmallVector<int64_t, 4> newExpandShape(srcShape.begin(),
srcShape.end());
newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1);
auto newExpandTy = RankedTensorType::get(newExpandShape, elemTy);
auto newExpand = rewriter.create<triton::ExpandDimsOp>(
op.getLoc(), newExpandTy, src, op.getAxis());
auto newBroadcast = rewriter.create<triton::BroadcastOp>(
broadcast.getLoc(), op.getType(), newExpand.getResult());
rewriter.replaceOp(op, {newBroadcast.getResult()});
return mlir::success();
}
return mlir::failure();
}

View File

@@ -4,6 +4,7 @@ add_public_tablegen_target(TritonCombineIncGen)
add_mlir_dialect_library(TritonTransforms
Combine.cpp
ReorderBroadcast.cpp
RewriteTensorPointer.cpp
DEPENDS

View File

@@ -0,0 +1,242 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include <memory>
namespace mlir {
#define GEN_PASS_DEF_TRITONREORDERBROADCAST
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
// Re-create `op` with the same name and attributes but with `newOperands`
// and results of `newTypes`. Regions and successors are not copied; callers
// only use this for region-free elementwise ops.
Operation *cloneWithNewArgsAndResultTypes(PatternRewriter &rewriter,
                                          Operation *op, ValueRange newOperands,
                                          TypeRange newTypes) {
  OperationState state(op->getLoc(), op->getName());
  state.addOperands(newOperands);
  state.addTypes(newTypes);
  state.addAttributes(op->getAttrs());
  return rewriter.create(state);
}
// Returns true if `op` produces a tensor whose elements are all identical:
// either a tt.splat, or a constant holding a splat DenseElementsAttr.
bool isSplat(Operation *op) {
  // Use isa<> rather than binding a dyn_cast result we never read
  // (the original triggered -Wunused-variable).
  if (llvm::isa<triton::SplatOp>(op)) {
    return true;
  }
  DenseElementsAttr constAttr;
  return (matchPattern(op, m_Constant(&constAttr)) && constAttr.isSplat());
}
// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))
//
// When every operand of a pure elementwise op is splat-like, the op can be
// computed once on scalars and the result splatted, instead of repeating the
// same scalar computation per element.
struct MoveSplatAfterElementwisePattern
    : public mlir::OpTraitRewritePattern<mlir::OpTrait::Elementwise> {

  MoveSplatAfterElementwisePattern(mlir::MLIRContext *context)
      : OpTraitRewritePattern(context) {}

  mlir::LogicalResult match(Operation *op) const override {
    // Ops with memory effects cannot safely be re-created at scalar shape.
    if (!isMemoryEffectFree(op)) {
      return mlir::failure();
    }

    // Every operand must come from a splat-like producer; block arguments
    // (no defining op) disqualify the match.
    for (auto operand : op->getOperands()) {
      auto definingOp = operand.getDefiningOp();
      if (!definingOp)
        return mlir::failure();

      if (!isSplat(definingOp)) {
        return mlir::failure();
      }
    }
    // Zero-operand ops have nothing to reorder.
    return mlir::success(op->getNumOperands() > 0);
  }

  void rewrite(Operation *op, mlir::PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    auto operands = op->getOperands();
    // Collect the scalar value feeding each splat-like operand.
    llvm::SmallVector<Value, 4> scalarOperands(operands.size());
    for (unsigned iOp = 0; iOp < operands.size(); ++iOp) {
      auto definingOp = operands[iOp].getDefiningOp();
      DenseElementsAttr constAttr;

      if (auto splatOp = llvm::dyn_cast<triton::SplatOp>(definingOp)) {
        scalarOperands[iOp] = splatOp.getSrc();
      } else if (matchPattern(definingOp, m_Constant(&constAttr)) &&
                 constAttr.isSplat()) {
        // Re-materialize a splat constant as the corresponding scalar
        // constant.
        auto value = constAttr.getSplatValue<Attribute>();
        scalarOperands[iOp] = arith::ConstantOp::materialize(
            rewriter, value, constAttr.getElementType(), loc);
      } else {
        // match() guarantees every operand is splat-like.
        llvm_unreachable("Expected a splat");
      }
    }

    // The scalar clone produces each tensor result's element type.
    auto resultTypes = op->getResultTypes();
    llvm::SmallVector<Type, 4> scalarResultTys;
    for (auto resultTy : resultTypes) {
      auto elemTy = resultTy.dyn_cast<TensorType>().getElementType();
      scalarResultTys.push_back(elemTy);
    }

    auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, scalarOperands,
                                                scalarResultTys);

    // Splat each scalar result back to the original tensor type and redirect
    // all uses of the old results.
    for (unsigned iRes = 0; iRes < resultTypes.size(); ++iRes) {
      auto newResult = rewriter.create<triton::SplatOp>(loc, resultTypes[iRes],
                                                        newOp->getResult(iRes));
      rewriter.replaceAllUsesWith(op->getResult(iRes), newResult);
    }
  }
};
// elementwise(broadcast(a)) => broadcast(elementwise(a))
// This also generalizes to multiple arguments when the rest are splat-like.
// Not handled: multiple broadcasted arguments.
//
// The elementwise op is re-created at the broadcast source's (smaller) shape,
// with splat-like operands re-materialized at that shape, and the results are
// broadcast back to the original type.
struct MoveBroadcastAfterElementwisePattern
    : public mlir::OpTraitRewritePattern<mlir::OpTrait::Elementwise> {

  MoveBroadcastAfterElementwisePattern(mlir::MLIRContext *context)
      : OpTraitRewritePattern(context) {}

  mlir::LogicalResult match(Operation *op) const override {
    // Ops with memory effects cannot safely be re-created at another shape.
    if (!isMemoryEffectFree(op)) {
      return mlir::failure();
    }

    auto operands = op->getOperands();
    bool seenBroadcast = false;
    for (auto operand : operands) {
      auto definingOp = operand.getDefiningOp();
      if (!definingOp) {
        // Block argument: shape provenance unknown.
        return mlir::failure();
      }
      if (llvm::isa<triton::BroadcastOp>(definingOp)) {
        if (seenBroadcast) {
          // Only support one broadcasted argument for now
          return mlir::failure();
        }
        seenBroadcast = true;
      } else if (!isSplat(definingOp)) {
        // Not splat or broadcast
        return mlir::failure();
      }
    }
    return mlir::success(seenBroadcast);
  }

  void rewrite(Operation *op, mlir::PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    // Find the broadcast op (unique, guaranteed by match()).
    auto operands = op->getOperands();
    triton::BroadcastOp broadcastOp;
    for (auto operand : operands) {
      // Extra parentheses: the assignment-in-condition is intentional
      // (silences -Wparentheses).
      if ((broadcastOp = operand.getDefiningOp<triton::BroadcastOp>())) {
        break;
      }
    }

    auto src = broadcastOp.getSrc();
    // tt.broadcast operates on ranked tensors, so cast<> (which asserts) is
    // correct here; the original dyn_cast<> would have null-dereferenced on
    // failure anyway.
    auto srcTy = src.getType().cast<RankedTensorType>();
    auto srcShape = srcTy.getShape();
    auto srcEncoding = srcTy.getEncoding();

    // Rebuild every operand at srcShape.
    llvm::SmallVector<Value, 4> newOperands;
    for (auto operand : operands) {
      auto definingOp = operand.getDefiningOp();
      if (llvm::isa<triton::BroadcastOp>(definingOp)) {
        // The broadcast operand is replaced by its (pre-broadcast) source.
        newOperands.push_back(src);
        continue;
      }
      auto elemTy =
          operand.getType().cast<RankedTensorType>().getElementType();
      auto newTy = RankedTensorType::get(srcShape, elemTy, srcEncoding);
      if (auto splatOp = llvm::dyn_cast<triton::SplatOp>(definingOp)) {
        // Splat directly to the smaller shape.
        auto newSplat =
            rewriter.create<triton::SplatOp>(loc, newTy, splatOp.getSrc());
        newOperands.push_back(newSplat);
        continue;
      }
      DenseElementsAttr constAttr;
      if (matchPattern(definingOp, m_Constant(&constAttr)) &&
          constAttr.isSplat()) {
        // Re-materialize a splat constant at the smaller shape.
        auto scalarValue = constAttr.getSplatValue<Attribute>();
        auto splatValue = SplatElementsAttr::get(newTy, scalarValue);
        auto newConstant =
            rewriter.create<arith::ConstantOp>(loc, newTy, splatValue);
        newOperands.push_back(newConstant);
        continue;
      }
      // match() guarantees broadcast-or-splat operands.
      llvm_unreachable("Expected broadcast or splat");
    }

    // Results are produced at srcShape, then broadcast back below.
    llvm::SmallVector<Type, 4> newResultTypes;
    auto resultTypes = op->getResultTypes();
    for (auto resultTy : resultTypes) {
      auto elemTy = resultTy.cast<RankedTensorType>().getElementType();
      newResultTypes.push_back(
          RankedTensorType::get(srcShape, elemTy, srcEncoding));
    }

    // Create the new op and broadcast its results to the original types.
    auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, newOperands,
                                                newResultTypes);
    for (unsigned iRes = 0; iRes < newResultTypes.size(); ++iRes) {
      auto newResult = rewriter.create<triton::BroadcastOp>(
          loc, resultTypes[iRes], newOp->getResult(iRes));
      rewriter.replaceAllUsesWith(op->getResult(iRes), newResult);
    }
  }
};
// Adapter exposing an op's static `canonicalize` hook as a standalone rewrite
// pattern, so those canonicalizations can run as part of this pass's greedy
// driver instead of requiring a separate canonicalizer pass.
template <typename ConcreteOp>
class CanonicalizePattern : public mlir::OpRewritePattern<ConcreteOp> {
public:
  explicit CanonicalizePattern(mlir::MLIRContext *context)
      : mlir::OpRewritePattern<ConcreteOp>(context) {}

  mlir::LogicalResult
  matchAndRewrite(ConcreteOp op,
                  mlir::PatternRewriter &rewriter) const override {
    return ConcreteOp::canonicalize(op, rewriter);
  }
};
// Pass driver: greedily applies the reordering patterns, plus the
// broadcast/expand_dims canonicalizers that put the IR in a form where the
// reordering patterns can fire.
class ReorderBroadcastPass
    : public mlir::impl::TritonReorderBroadcastBase<ReorderBroadcastPass> {
public:
  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    mlir::ModuleOp moduleOp = getOperation();

    // elementwise(broadcast(a)) => broadcast(elementwise(a))
    // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))
    patterns.add<CanonicalizePattern<triton::BroadcastOp>,
                 CanonicalizePattern<triton::ExpandDimsOp>,
                 MoveBroadcastAfterElementwisePattern,
                 MoveSplatAfterElementwisePattern>(ctx);

    if (applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)).failed())
      signalPassFailure();
  }
};
} // namespace
// Factory for the `triton-reorder-broadcast` pass (registered as
// TritonReorderBroadcast in Passes.td).
std::unique_ptr<mlir::Pass> mlir::triton::createReorderBroadcastPass() {
  return std::make_unique<ReorderBroadcastPass>();
}

View File

@@ -1561,6 +1561,10 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::triton::createCombineOpsPass());
})
.def("add_reorder_broadcast_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::triton::createReorderBroadcastPass());
})
.def("add_rewrite_tensor_pointer_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(mlir::triton::createRewriteTensorPointerPass(

View File

@@ -55,6 +55,7 @@ def optimize_ttir(mod, arch):
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_reorder_broadcast_pass()
pm.add_cse_pass()
pm.add_licm_pass()
pm.add_symbol_dce_pass()

View File

@@ -152,12 +152,18 @@ tt.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>
}
// CHECK-LABEL: @test_canonicalize_expand_dims
tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>) -> (tensor<1x8xf32>) {
tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>, %arg1: tensor<1xf32>) -> (tensor<1x8xf32>, tensor<8x8xf32>) {
%splat = tt.splat %arg0 : (tensor<f32>) -> tensor<8xf32>
// CHECK: %{{.*}} = tt.splat %arg0 : (tensor<f32>) -> tensor<1x8xf32>
%ed = tt.expand_dims %splat {axis = 0 : i32} : (tensor<8xf32>) -> tensor<1x8xf32>
tt.return %ed : tensor<1x8xf32>
// CHECK-NEXT: %[[ed2:.*]] = tt.expand_dims %arg1 {axis = 0 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[ed2]] : (tensor<1x1xf32>) -> tensor<8x8xf32>
%bc = tt.broadcast %arg1 : (tensor<1xf32>) -> tensor<8xf32>
%ed2 = tt.expand_dims %bc {axis = 0 : i32} : (tensor<8xf32>) -> tensor<1x8xf32>
%bc2 = tt.broadcast %ed2 : (tensor<1x8xf32>) -> tensor<8x8xf32>
tt.return %ed, %bc2 : tensor<1x8xf32>, tensor<8x8xf32>
}

View File

@@ -0,0 +1,40 @@
// RUN: triton-opt %s -split-input-file -triton-reorder-broadcast | FileCheck %s
// CHECK-LABEL: @test_splat_elementwise_pattern
tt.func @test_splat_elementwise_pattern(%arg0: f32) -> (tensor<128x128xf32>, tensor<128x128x!tt.ptr<f32>>) {
// CHECK-DAG: %[[a:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : i64
%c1 = arith.constant 1 : i64
%a = arith.constant dense<1.0> : tensor<128x128xf32>
// CHECK-DAG: %[[add:.*]] = arith.addf %arg0, %[[a]] : f32
// CHECK-NEXT: %[[splat:.*]] = tt.splat %[[add]] : (f32) -> tensor<128x128xf32>
%b = tt.splat %arg0 : (f32) -> tensor<128x128xf32>
%add = arith.addf %a, %b : tensor<128x128xf32>
// CHECK-NEXT: %[[ptr:.*]] = tt.int_to_ptr %[[c1]] : i64 -> !tt.ptr<f32>
// CHECK-NEXT: %{{.*}} = tt.splat %[[ptr]] : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
%c1_t = tt.splat %c1 : (i64) -> tensor<128x128xi64>
%ptr = tt.int_to_ptr %c1_t : tensor<128x128xi64> -> tensor<128x128x!tt.ptr<f32>>
tt.return %add, %ptr : tensor<128x128xf32>, tensor<128x128x!tt.ptr<f32>>
}
// CHECK-LABEL: @test_broadcast_elementwise_pattern
tt.func @test_broadcast_elementwise_pattern(%arg0: tensor<128x1xf32>) -> (tensor<128x128xf32>, tensor<128x32xf32>) {
// CHECK: %[[one:.*]] = arith.constant dense<1.000000e+00> : tensor<128x1xf32>
// CHECK-NEXT: %[[abs:.*]] = math.absf %arg0 : tensor<128x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[abs]] : (tensor<128x1xf32>) -> tensor<128x128xf32>
%broadcast = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32>
%abs = math.absf %broadcast : tensor<128x128xf32>
// CHECK-NEXT: %[[add:.*]] = arith.addf %arg0, %[[one]] : tensor<128x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[add]] : (tensor<128x1xf32>) -> tensor<128x32xf32>
%broadcast2 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x32xf32>
%one = arith.constant dense<1.0> : tensor<128x32xf32>
%add = arith.addf %one, %broadcast2 : tensor<128x32xf32>
tt.return %abs, %add : tensor<128x128xf32>, tensor<128x32xf32>
}