mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Add pass to move broadcasts after elementwise operations (#1811)
This adds a pass that tries to reduce the shape of tensor arguments to
element-wise operations by moving splat and broadcast operations later
in the graph. So, for example, say we have:
```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (0))
tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
tmp2 = 0.017453292519943295
tmp3 = tmp1 * tmp2
tmp4 = tl.sin(tmp3)
tl.store(out_ptr0 + (x0), tmp4, None)
```
Today this results in duplicate `sin` calls:
```
%27 = llvm.fmul %26, %3 : f32
%28 = llvm.call @__nv_sinf(%27) : (f32) -> f32
%29 = llvm.call @__nv_sinf(%27) : (f32) -> f32
```
The duplicate `llvm.fmul` calls are eliminated via CSE, but `llvm.call`
doesn't get CSE'd because it might be impure.
After this change, the sin is done on a scalar value in the triton IR
and splatted at the very end, so no duplicate calculation happens within
a thread.
---------
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -31,7 +31,8 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
// fptoui, fptosi, uitofp, sitofp,
|
||||
// extf, truncf,
|
||||
// extui, extsi, trunci
|
||||
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
|
||||
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
Pure,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
@@ -44,7 +45,8 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
|
||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
Pure,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
@@ -58,7 +60,8 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
|
||||
}
|
||||
|
||||
// arith.bitcast doesn't support pointers
|
||||
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
|
||||
def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
Pure,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
@@ -73,6 +76,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
|
||||
// TODO: Add verifier
|
||||
}
|
||||
|
||||
// FIXME: Not elementwise because scalars are not supported
|
||||
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
Pure,
|
||||
@@ -99,6 +103,7 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
|
||||
//
|
||||
def TT_AddPtrOp : TT_Op<"addptr",
|
||||
[Pure,
|
||||
Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
TypesMatchWith<"result type matches ptr type",
|
||||
@@ -458,7 +463,8 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
|
||||
//
|
||||
class TT_ExternElementwiseOpBase<string mnemonic, list<Trait> traits = []> :
|
||||
TT_Op<mnemonic,
|
||||
traits # [SameOperandsAndResultEncoding,
|
||||
traits # [Elementwise,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize]> {
|
||||
|
||||
let description = [{
|
||||
|
||||
@@ -8,6 +8,8 @@ namespace triton {
|
||||
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
|
||||
std::unique_ptr<Pass> createReorderBroadcastPass();
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createRewriteTensorPointerPass(int computeCapability = 80);
|
||||
|
||||
|
||||
@@ -19,6 +19,15 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
|
||||
let dependentDialects = ["mlir::arith::ArithDialect"];
|
||||
}
|
||||
|
||||
def TritonReorderBroadcast : Pass</*cli-arg*/"triton-reorder-broadcast", /*Op*/"mlir::ModuleOp"> {
|
||||
let summary = "Moves broadcast and splat after elementwise operations";
|
||||
let description = [{
|
||||
elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))
|
||||
}];
|
||||
let constructor = "mlir::triton::createReorderBroadcastPass()";
|
||||
let dependentDialects = ["mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer", /*Op*/"mlir::ModuleOp"> {
|
||||
let summary = "Rewrite load/stores with tensor pointers into legacy load/stores";
|
||||
let description = [{
|
||||
|
||||
@@ -641,6 +641,31 @@ LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op,
|
||||
splat.getOperand());
|
||||
return mlir::success();
|
||||
}
|
||||
// expand_dims(broadcast) -> broadcast(expand_dims)
|
||||
//
|
||||
// On its own this doesn't do much, but consider
|
||||
// broadcast(expand_dims(broadcast))
|
||||
// -> broadcast(broadcast(expand_dims))
|
||||
// -> broadcast(expand_dims)
|
||||
if (auto broadcast = dyn_cast<triton::BroadcastOp>(definingOp)) {
|
||||
auto src = broadcast.getSrc();
|
||||
auto srcTy = src.getType().dyn_cast<RankedTensorType>();
|
||||
auto elemTy = srcTy.getElementType();
|
||||
auto srcShape = srcTy.getShape();
|
||||
|
||||
llvm::SmallVector<int64_t, 4> newExpandShape(srcShape.begin(),
|
||||
srcShape.end());
|
||||
newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1);
|
||||
auto newExpandTy = RankedTensorType::get(newExpandShape, elemTy);
|
||||
|
||||
auto newExpand = rewriter.create<triton::ExpandDimsOp>(
|
||||
op.getLoc(), newExpandTy, src, op.getAxis());
|
||||
auto newBroadcast = rewriter.create<triton::BroadcastOp>(
|
||||
broadcast.getLoc(), op.getType(), newExpand.getResult());
|
||||
rewriter.replaceOp(op, {newBroadcast.getResult()});
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ add_public_tablegen_target(TritonCombineIncGen)
|
||||
|
||||
add_mlir_dialect_library(TritonTransforms
|
||||
Combine.cpp
|
||||
ReorderBroadcast.cpp
|
||||
RewriteTensorPointer.cpp
|
||||
|
||||
DEPENDS
|
||||
|
||||
242
lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp
Normal file
242
lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp
Normal file
@@ -0,0 +1,242 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_TRITONREORDERBROADCAST
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
// Re-create `op` with the same name, location and attributes, but with the
// given operands and result types. Regions and successors are not copied;
// callers only apply this to region-free elementwise operations.
Operation *cloneWithNewArgsAndResultTypes(PatternRewriter &rewriter,
                                          Operation *op, ValueRange newOperands,
                                          TypeRange newTypes) {
  OperationState state(op->getLoc(), op->getName());
  state.addAttributes(op->getAttrs());
  state.addOperands(newOperands);
  state.addTypes(newTypes);
  return rewriter.create(state);
}
|
||||
|
||||
// Returns true if `op` produces a value whose elements are all identical:
// either a tt.splat, or a constant whose DenseElementsAttr is a splat.
bool isSplat(Operation *op) {
  // LLVM style: use isa<> rather than binding an unused dyn_cast result.
  if (llvm::isa<triton::SplatOp>(op)) {
    return true;
  }
  DenseElementsAttr constAttr;
  return (matchPattern(op, m_Constant(&constAttr)) && constAttr.isSplat());
}
|
||||
|
||||
// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))
//
// Rewrites an elementwise op whose operands are ALL splat-like (tt.splat or
// splat constants) into a scalar elementwise op followed by a single
// tt.splat of each result. This lets later passes (e.g. CSE) operate on the
// scalar computation instead of the broadcast tensor.
struct MoveSplatAfterElementwisePattern
    : public mlir::OpTraitRewritePattern<mlir::OpTrait::Elementwise> {

  MoveSplatAfterElementwisePattern(mlir::MLIRContext *context)
      : OpTraitRewritePattern(context) {}

  // Match only pure ops (no memory effects) where every operand is defined
  // by a splat-like op. `rewrite` relies on this invariant.
  mlir::LogicalResult match(Operation *op) const override {
    if (!isMemoryEffectFree(op)) {
      return mlir::failure();
    }

    for (auto operand : op->getOperands()) {
      auto definingOp = operand.getDefiningOp();
      // Block arguments have no defining op; we cannot see through them.
      if (!definingOp)
        return mlir::failure();

      if (!isSplat(definingOp)) {
        return mlir::failure();
      }
    }
    // Zero-operand ops (e.g. constants) have nothing to sink past.
    return mlir::success(op->getNumOperands() > 0);
  }

  void rewrite(Operation *op, mlir::PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto operands = op->getOperands();

    // Extract the scalar behind each splat-like operand.
    llvm::SmallVector<Value, 4> scalarOperands(operands.size());
    for (unsigned iOp = 0; iOp < operands.size(); ++iOp) {
      auto definingOp = operands[iOp].getDefiningOp();

      DenseElementsAttr constAttr;
      if (auto splatOp = llvm::dyn_cast<triton::SplatOp>(definingOp)) {
        scalarOperands[iOp] = splatOp.getSrc();
      } else if (matchPattern(definingOp, m_Constant(&constAttr)) &&
                 constAttr.isSplat()) {
        // Materialize the splat constant's element as a scalar constant.
        auto value = constAttr.getSplatValue<Attribute>();
        scalarOperands[iOp] = arith::ConstantOp::materialize(
            rewriter, value, constAttr.getElementType(), loc);
      } else {
        // `match` guaranteed every operand is splat-like.
        llvm_unreachable("Expected a splat");
      }
    }

    // Scalar result types: strip the tensor wrapper from each result type.
    auto resultTypes = op->getResultTypes();
    llvm::SmallVector<Type, 4> scalarResultTys;
    for (auto resultTy : resultTypes) {
      auto elemTy = resultTy.dyn_cast<TensorType>().getElementType();
      scalarResultTys.push_back(elemTy);
    }

    // Clone the elementwise op on scalars.
    auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, scalarOperands,
                                                scalarResultTys);

    // Splat each scalar result back to the original tensor type. The old op
    // becomes dead and is cleaned up by the greedy rewrite driver.
    for (unsigned iRes = 0; iRes < resultTypes.size(); ++iRes) {
      auto newResult = rewriter.create<triton::SplatOp>(loc, resultTypes[iRes],
                                                        newOp->getResult(iRes));
      rewriter.replaceAllUsesWith(op->getResult(iRes), newResult);
    }
  }
};
|
||||
|
||||
// elementwise(broadcast(a)) => broadcast(elementwise(a))
|
||||
// This also generalizes to multiple arguments when the rest are splat-like
|
||||
// Not handled: multiple broadcasted arguments
|
||||
struct MoveBroadcastAfterElementwisePattern
|
||||
: public mlir::OpTraitRewritePattern<mlir::OpTrait::Elementwise> {
|
||||
|
||||
MoveBroadcastAfterElementwisePattern(mlir::MLIRContext *context)
|
||||
: OpTraitRewritePattern(context) {}
|
||||
|
||||
mlir::LogicalResult match(Operation *op) const override {
|
||||
if (!isMemoryEffectFree(op)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto operands = op->getOperands();
|
||||
bool seenBroadcast = false;
|
||||
for (auto operand : operands) {
|
||||
auto definingOp = operand.getDefiningOp();
|
||||
if (!definingOp) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (auto broadcastOp = llvm::dyn_cast<triton::BroadcastOp>(definingOp)) {
|
||||
if (seenBroadcast) {
|
||||
// Only support one broadcasted argument for now
|
||||
return mlir::failure();
|
||||
}
|
||||
seenBroadcast = true;
|
||||
} else if (!isSplat(definingOp)) {
|
||||
// Not splat or broadcast
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
return mlir::success(seenBroadcast);
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, mlir::PatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Find broadcast op
|
||||
auto operands = op->getOperands();
|
||||
triton::BroadcastOp broadcastOp;
|
||||
for (auto operand : operands) {
|
||||
if (broadcastOp = operand.getDefiningOp<triton::BroadcastOp>()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto src = broadcastOp.getSrc();
|
||||
auto srcTy = src.getType().dyn_cast<RankedTensorType>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto srcEncoding = srcTy.getEncoding();
|
||||
|
||||
// Reshape operands to match srcShape
|
||||
llvm::SmallVector<Value, 4> newOperands;
|
||||
for (auto operand : operands) {
|
||||
auto definingOp = operand.getDefiningOp();
|
||||
if (llvm::isa<triton::BroadcastOp>(definingOp)) {
|
||||
newOperands.push_back(src);
|
||||
continue;
|
||||
}
|
||||
auto elemTy =
|
||||
operand.getType().dyn_cast<RankedTensorType>().getElementType();
|
||||
auto newTy = RankedTensorType::get(srcShape, elemTy, srcEncoding);
|
||||
if (auto splatOp = llvm::dyn_cast<triton::SplatOp>(definingOp)) {
|
||||
auto newSplat =
|
||||
rewriter.create<triton::SplatOp>(loc, newTy, splatOp.getSrc());
|
||||
newOperands.push_back(newSplat);
|
||||
continue;
|
||||
}
|
||||
DenseElementsAttr constAttr;
|
||||
if (matchPattern(definingOp, m_Constant(&constAttr)) &&
|
||||
constAttr.isSplat()) {
|
||||
auto scalarValue = constAttr.getSplatValue<Attribute>();
|
||||
auto splatValue = SplatElementsAttr::get(newTy, scalarValue);
|
||||
auto newConstant =
|
||||
rewriter.create<arith::ConstantOp>(loc, newTy, splatValue);
|
||||
newOperands.push_back(newConstant);
|
||||
continue;
|
||||
}
|
||||
llvm_unreachable("Expected broadcast or splat");
|
||||
}
|
||||
|
||||
// Reshape results to match srcShape
|
||||
llvm::SmallVector<Type, 4> newResultTypes;
|
||||
auto resultTypes = op->getResultTypes();
|
||||
for (auto resultTy : resultTypes) {
|
||||
auto elemTy = resultTy.dyn_cast<RankedTensorType>().getElementType();
|
||||
newResultTypes.push_back(
|
||||
RankedTensorType::get(srcShape, elemTy, srcEncoding));
|
||||
}
|
||||
|
||||
// Create new op and broadcast results
|
||||
auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, newOperands,
|
||||
newResultTypes);
|
||||
for (unsigned iRes = 0; iRes < newResultTypes.size(); ++iRes) {
|
||||
auto newResult = rewriter.create<triton::BroadcastOp>(
|
||||
loc, resultTypes[iRes], newOp->getResult(iRes));
|
||||
rewriter.replaceAllUsesWith(op->getResult(iRes), newResult);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpType>
|
||||
class CanonicalizePattern : public mlir::OpRewritePattern<OpType> {
|
||||
public:
|
||||
explicit CanonicalizePattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<OpType>(context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(OpType op, mlir::PatternRewriter &rewriter) const override {
|
||||
return OpType::canonicalize(op, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
class ReorderBroadcastPass
|
||||
: public mlir::impl::TritonReorderBroadcastBase<ReorderBroadcastPass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
mlir::MLIRContext *context = &getContext();
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
mlir::ModuleOp m = getOperation();
|
||||
|
||||
patterns.add<CanonicalizePattern<triton::BroadcastOp>>(context);
|
||||
patterns.add<CanonicalizePattern<triton::ExpandDimsOp>>(context);
|
||||
// elementwise(broadcast(a)) => broadcast(elementwise(a))
|
||||
patterns.add<MoveBroadcastAfterElementwisePattern>(context);
|
||||
// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))
|
||||
patterns.add<MoveSplatAfterElementwisePattern>(context);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Factory for the reorder-broadcast pass; used by pass registration and the
// Python bindings (`add_reorder_broadcast_pass`).
std::unique_ptr<mlir::Pass> mlir::triton::createReorderBroadcastPass() {
  return std::make_unique<ReorderBroadcastPass>();
}
|
||||
@@ -1561,6 +1561,10 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createCombineOpsPass());
|
||||
})
|
||||
.def("add_reorder_broadcast_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createReorderBroadcastPass());
|
||||
})
|
||||
.def("add_rewrite_tensor_pointer_pass",
|
||||
[](mlir::PassManager &self, int computeCapability) {
|
||||
self.addPass(mlir::triton::createRewriteTensorPointerPass(
|
||||
|
||||
@@ -55,6 +55,7 @@ def optimize_ttir(mod, arch):
|
||||
pm.add_inliner_pass()
|
||||
pm.add_triton_combine_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_reorder_broadcast_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_licm_pass()
|
||||
pm.add_symbol_dce_pass()
|
||||
|
||||
@@ -152,12 +152,18 @@ tt.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_expand_dims
|
||||
tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>) -> (tensor<1x8xf32>) {
|
||||
tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>, %arg1: tensor<1xf32>) -> (tensor<1x8xf32>, tensor<8x8xf32>) {
|
||||
%splat = tt.splat %arg0 : (tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK: %{{.*}} = tt.splat %arg0 : (tensor<f32>) -> tensor<1x8xf32>
|
||||
%ed = tt.expand_dims %splat {axis = 0 : i32} : (tensor<8xf32>) -> tensor<1x8xf32>
|
||||
|
||||
tt.return %ed : tensor<1x8xf32>
|
||||
// CHECK-NEXT: %[[ed2:.*]] = tt.expand_dims %arg1 {axis = 0 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
|
||||
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[ed2]] : (tensor<1x1xf32>) -> tensor<8x8xf32>
|
||||
%bc = tt.broadcast %arg1 : (tensor<1xf32>) -> tensor<8xf32>
|
||||
%ed2 = tt.expand_dims %bc {axis = 0 : i32} : (tensor<8xf32>) -> tensor<1x8xf32>
|
||||
%bc2 = tt.broadcast %ed2 : (tensor<1x8xf32>) -> tensor<8x8xf32>
|
||||
|
||||
tt.return %ed, %bc2 : tensor<1x8xf32>, tensor<8x8xf32>
|
||||
}
|
||||
|
||||
|
||||
|
||||
40
test/Triton/reorder-broadcast.mlir
Normal file
40
test/Triton/reorder-broadcast.mlir
Normal file
@@ -0,0 +1,40 @@
|
||||
// RUN: triton-opt %s -split-input-file -triton-reorder-broadcast | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_splat_elementwise_pattern
|
||||
tt.func @test_splat_elementwise_pattern(%arg0: f32) -> (tensor<128x128xf32>, tensor<128x128x!tt.ptr<f32>>) {
|
||||
// CHECK-DAG: %[[a:.*]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : i64
|
||||
%c1 = arith.constant 1 : i64
|
||||
%a = arith.constant dense<1.0> : tensor<128x128xf32>
|
||||
|
||||
// CHECK-DAG: %[[add:.*]] = arith.addf %arg0, %[[a]] : f32
|
||||
// CHECK-NEXT: %[[splat:.*]] = tt.splat %[[add]] : (f32) -> tensor<128x128xf32>
|
||||
%b = tt.splat %arg0 : (f32) -> tensor<128x128xf32>
|
||||
%add = arith.addf %a, %b : tensor<128x128xf32>
|
||||
|
||||
|
||||
// CHECK-NEXT: %[[ptr:.*]] = tt.int_to_ptr %[[c1]] : i64 -> !tt.ptr<f32>
|
||||
// CHECK-NEXT: %{{.*}} = tt.splat %[[ptr]] : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
%c1_t = tt.splat %c1 : (i64) -> tensor<128x128xi64>
|
||||
%ptr = tt.int_to_ptr %c1_t : tensor<128x128xi64> -> tensor<128x128x!tt.ptr<f32>>
|
||||
|
||||
tt.return %add, %ptr : tensor<128x128xf32>, tensor<128x128x!tt.ptr<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_broadcast_elementwise_pattern
|
||||
tt.func @test_broadcast_elementwise_pattern(%arg0: tensor<128x1xf32>) -> (tensor<128x128xf32>, tensor<128x32xf32>) {
|
||||
// CHECK: %[[one:.*]] = arith.constant dense<1.000000e+00> : tensor<128x1xf32>
|
||||
|
||||
// CHECK-NEXT: %[[abs:.*]] = math.absf %arg0 : tensor<128x1xf32>
|
||||
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[abs]] : (tensor<128x1xf32>) -> tensor<128x128xf32>
|
||||
%broadcast = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32>
|
||||
%abs = math.absf %broadcast : tensor<128x128xf32>
|
||||
|
||||
// CHECK-NEXT: %[[add:.*]] = arith.addf %arg0, %[[one]] : tensor<128x1xf32>
|
||||
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[add]] : (tensor<128x1xf32>) -> tensor<128x32xf32>
|
||||
%broadcast2 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x32xf32>
|
||||
%one = arith.constant dense<1.0> : tensor<128x32xf32>
|
||||
%add = arith.addf %one, %broadcast2 : tensor<128x32xf32>
|
||||
|
||||
tt.return %abs, %add : tensor<128x128xf32>, tensor<128x32xf32>
|
||||
}
|
||||
Reference in New Issue
Block a user