Andi Drebes b24709a1ec feat(compiler): Batching: Hoist non-batchable operands produced by pure ops
The batching pass only creates a batched version of a batchable
operation if all of its non-batchable operands are defined outside of
the outermost loop iterating over the values of the batchable operand.

This change also allows operations to be batched if the non-batchable
operands are generated by operations that are pure and thus hoistable
out of the outermost loop.
2023-03-24 11:06:51 +01:00


// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <llvm/ADT/STLExtras.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include <concretelang/Interfaces/BatchableInterface.h>
#include <concretelang/Transforms/Passes.h>
namespace mlir {
namespace concretelang {
/// Checks if the value `v` is defined outside of `loop` or produced by
/// a pure operation that can be safely replicated outside the loop
/// (i.e., all of its operands are also recursively either defined
/// outside of the loop or produced by pure operations).
static bool isHoistable(mlir::Value v, mlir::scf::ForOp loop) {
mlir::Operation *op = v.getDefiningOp();
return loop.isDefinedOutsideOfLoop(v) ||
(op && mlir::isPure(op) && op->getNumResults() == 1 &&
llvm::all_of(op->getOperands(), [&](mlir::Value operand) {
return isHoistable(operand, loop);
}));
}
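// Illustrative example (hypothetical IR, not taken from a test case):
// given
//
//   %cst = arith.constant 7 : i64
//   %sum = arith.addi %cst, %x : i64
//
// with %x defined above the loop, both %cst and %sum are hoistable,
// since their defining operations are pure and all operands
// recursively bottom out in values defined outside of the loop.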
/// Checks if `forOp` has constant bounds and a constant step.
static bool isStaticLoop(mlir::scf::ForOp forOp, int64_t *ilb = nullptr,
int64_t *iub = nullptr, int64_t *istep = nullptr) {
mlir::Operation *lbOp = forOp.getLowerBound().getDefiningOp();
mlir::Operation *ubOp = forOp.getUpperBound().getDefiningOp();
mlir::Operation *stepOp = forOp.getStep().getDefiningOp();
if (!lbOp || !ubOp || !stepOp)
return false;
mlir::arith::ConstantIndexOp lb =
llvm::dyn_cast<mlir::arith::ConstantIndexOp>(lbOp);
mlir::arith::ConstantIndexOp ub =
llvm::dyn_cast<mlir::arith::ConstantIndexOp>(ubOp);
mlir::arith::ConstantIndexOp step =
llvm::dyn_cast<mlir::arith::ConstantIndexOp>(stepOp);
if (lb && ub && step) {
if (ilb)
*ilb = lb.value();
if (iub)
*iub = ub.value();
if (istep)
*istep = step.value();
return true;
}
return false;
}
/// Checks if `forOp` is a loop with a lower bound of 0, a constant
/// upper bound and a constant step of 1
static bool isStaticNormalizedLoop(mlir::scf::ForOp forOp) {
int64_t lb;
int64_t step;
if (isStaticLoop(forOp, &lb, nullptr, &step))
return (lb == 0 && step == 1);
return false;
}
/// Returns an `OpFoldResult` with an `IntegerAttr` value if `v` is
/// produced by a constant, otherwise an `OpFoldResult` containing `v`
/// itself.
static mlir::OpFoldResult getValueAsOpFoldResult(mlir::Value v) {
if (mlir::arith::ConstantOp cstOp =
dyn_cast_or_null<mlir::arith::ConstantOp>(v.getDefiningOp())) {
return cstOp.getValue();
}
return v;
}
/// Assumes that `v` is produced by a constant index operation and
/// returns its constant value as an `int64_t`.
static int64_t getConstantIndexValue(mlir::Value v) {
assert(v.getDefiningOp() &&
llvm::isa<mlir::arith::ConstantIndexOp>(*v.getDefiningOp()));
return llvm::dyn_cast<mlir::arith::ConstantIndexOp>(*v.getDefiningOp())
.value();
}
/// Returns a `Value` from an `OpFoldResult`. If the `OpFoldResult` is
/// already a value, the value is returned as is. Otherwise a
/// constant op is created using `builder`.
static mlir::Value getOpFoldResultAsValue(mlir::ImplicitLocOpBuilder &builder,
mlir::OpFoldResult v) {
if (v.is<mlir::Value>()) {
return v.dyn_cast<mlir::Value>();
} else {
return builder.create<mlir::arith::ConstantIndexOp>(
v.get<mlir::Attribute>().cast<mlir::IntegerAttr>().getInt());
}
}
/// Performs an arithmetic operation on `a` and `b`, where both values
/// can be any combination of `IntegerAttr` and `Value`.
template <typename ArithOp, typename ArithFunctor,
typename IsNeutralElementFunctor>
mlir::OpFoldResult opFoldExpr(mlir::ImplicitLocOpBuilder &builder,
mlir::OpFoldResult a, mlir::OpFoldResult b) {
static IsNeutralElementFunctor isNeutralElement;
auto exprValVal = [&](mlir::Value a, mlir::Value b) -> mlir::Value {
return builder.create<ArithOp>(a, b);
};
auto exprAttrVal = [&](mlir::IntegerAttr attr, mlir::Value v) -> mlir::Value {
mlir::Value cst =
builder.create<mlir::arith::ConstantIndexOp>(attr.getInt());
return exprValVal(cst, v);
};
auto exprValAttr = [&](mlir::Value v, mlir::IntegerAttr attr) -> mlir::Value {
mlir::Value cst =
builder.create<mlir::arith::ConstantIndexOp>(attr.getInt());
return exprValVal(v, cst);
};
auto exprAttrAttr = [&](mlir::IntegerAttr a,
mlir::IntegerAttr b) -> mlir::IntegerAttr {
static ArithFunctor f;
return builder.getIndexAttr(f(a.getInt(), b.getInt()));
};
if (a.is<mlir::Value>()) {
if (b.is<mlir::Value>()) {
return exprValVal(a.get<mlir::Value>(), b.get<mlir::Value>());
} else {
mlir::IntegerAttr bAttr =
b.get<mlir::Attribute>().cast<mlir::IntegerAttr>();
if (isNeutralElement(bAttr.getValue().getSExtValue())) {
return a;
} else {
return exprValAttr(a.get<mlir::Value>(), bAttr);
}
}
} else {
mlir::IntegerAttr aAttr =
a.get<mlir::Attribute>().cast<mlir::IntegerAttr>();
if (b.is<mlir::Value>()) {
return exprAttrVal(aAttr, b.get<mlir::Value>());
} else {
mlir::IntegerAttr bAttr =
b.get<mlir::Attribute>().cast<mlir::IntegerAttr>();
if (isNeutralElement(bAttr.getValue().getSExtValue()))
return a;
else
return exprAttrAttr(aAttr, bAttr);
}
}
}
/// Helper class whose call operator compares its argument to the
/// constant value `cst`.
template <typename T, const T cst> struct comparator {
bool operator()(const T &val) { return cst == val; }
};
/// Divides `a` by `b`, where both values can be any combination of
/// `IntegerAttr` and `Value`.
static mlir::OpFoldResult opFoldDiv(mlir::ImplicitLocOpBuilder &builder,
mlir::OpFoldResult a,
mlir::OpFoldResult b) {
return opFoldExpr<mlir::arith::DivSIOp, std::divides<int64_t>,
comparator<int64_t, 1>>(builder, a, b);
}
/// Subtracts `b` from `a`, where both values can be any combination
/// of `IntegerAttr` and `Value`.
static mlir::OpFoldResult opFoldSub(mlir::ImplicitLocOpBuilder &builder,
mlir::OpFoldResult a,
mlir::OpFoldResult b) {
return opFoldExpr<mlir::arith::SubIOp, std::minus<int64_t>,
comparator<int64_t, 0>>(builder, a, b);
}
/// Convenience class that holds all parameters of a loop
struct BoundsAndStep {
int64_t lb;
int64_t ub;
int64_t step;
BoundsAndStep operator+(const BoundsAndStep &other) {
return BoundsAndStep{lb + other.lb, ub + other.ub, step + other.step};
}
BoundsAndStep operator-(const BoundsAndStep &other) {
return BoundsAndStep{lb - other.lb, ub - other.ub, step - other.step};
}
BoundsAndStep operator*(const BoundsAndStep &other) {
return BoundsAndStep{lb * other.lb, ub * other.ub, step * other.step};
}
BoundsAndStep operator/(int64_t d) {
return BoundsAndStep{lb / d, ub / d, step / d};
}
};
/// Returns the lower bound, upper bound and step of the quasi-affine
/// expression `expr` on the induction variable of the for operation
/// `forOp`.
static std::optional<BoundsAndStep>
getBoundsOfQuasiAffineIVExpression(mlir::Value expr, mlir::scf::ForOp forOp) {
// Base case: expression is the induction variable itself -> return
// loop bounds
if (expr == forOp.getInductionVar()) {
return BoundsAndStep{getConstantIndexValue(forOp.getLowerBound()),
getConstantIndexValue(forOp.getUpperBound()),
getConstantIndexValue(forOp.getStep())};
}
// Arithmetic expression
else if (mlir::Operation *op = expr.getDefiningOp()) {
if (llvm::isa<mlir::arith::AddIOp, mlir::arith::SubIOp, mlir::arith::MulIOp,
mlir::arith::DivSIOp>(op)) {
std::optional<BoundsAndStep> lhs =
getBoundsOfQuasiAffineIVExpression(op->getOperand(0), forOp);
std::optional<BoundsAndStep> rhs =
getBoundsOfQuasiAffineIVExpression(op->getOperand(1), forOp);
if (!lhs.has_value() || !rhs.has_value())
return std::nullopt;
if (llvm::isa<mlir::arith::AddIOp>(op))
return *lhs + *rhs;
else if (llvm::isa<mlir::arith::SubIOp>(op))
return *lhs - *rhs;
else if (llvm::isa<mlir::arith::MulIOp>(op))
return (*lhs) * (*rhs);
else if (llvm::isa<mlir::arith::DivSIOp>(op)) {
assert(rhs->ub == rhs->lb && rhs->step == 0 &&
"Expression for divisor references IV");
int64_t rhsVal = rhs->ub;
assert(rhsVal != 0 && "Division by zero");
// If the step value of the subexpression is not a multiple of
// the divisor, there may be two iterations with the same
// value. Conservatively bail out.
if (lhs->step % rhsVal != 0)
return std::nullopt;
return *lhs / rhsVal;
}
}
// Base case: constant -> return constant value
else if (llvm::isa<mlir::arith::ConstantIndexOp>(expr.getDefiningOp())) {
mlir::arith::ConstantIndexOp cst =
llvm::dyn_cast<mlir::arith::ConstantIndexOp>(expr.getDefiningOp());
return BoundsAndStep{cst.value(), cst.value(), 0};
}
}
llvm_unreachable("Expression could not be evaluated statically");
}
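// Worked example: for the expression (%i + 5) / 7 with %i being the IV
// of `scf.for %i = %c0 to %c70 step %c7`, the subexpression %i + 5 has
// bounds {lb = 5, ub = 75, step = 7}; since the divisor 7 divides the
// step, the result is {lb = 0, ub = 10, step = 1}.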
/// Checks whether the expression `expr` is a quasi-affine expression
/// on a single induction variable. If an induction variable is
/// referenced, the owning for loop is returned in `*owningForOp`.
static bool isQuasiAffineIVExpression(mlir::Value expr,
mlir::scf::ForOp *owningForOp = nullptr) {
if (mlir::Operation *op = expr.getDefiningOp()) {
if (llvm::isa<mlir::arith::ConstantIndexOp>(op)) {
return true;
} else if (llvm::isa<mlir::arith::AddIOp, mlir::arith::SubIOp,
mlir::arith::MulIOp, mlir::arith::DivSIOp>(op)) {
mlir::scf::ForOp forLHS;
mlir::scf::ForOp forRHS;
if (!isQuasiAffineIVExpression(op->getOperand(0), &forLHS) ||
!isQuasiAffineIVExpression(op->getOperand(1), &forRHS)) {
return false;
} else {
// Check that appearances of IVs refer to the same IV
if (forLHS && forRHS && forLHS != forRHS)
return false;
}
// Assume that the expression is already canonicalized, so that
// IVs appear only in numerators and on one side of a
// multiplication subexpression
if ((llvm::isa<mlir::arith::MulIOp>(op) && forLHS && forRHS) ||
(llvm::isa<mlir::arith::DivSIOp>(op) && forRHS))
return false;
if (owningForOp != nullptr) {
if (forLHS)
*owningForOp = forLHS;
else if (forRHS)
*owningForOp = forRHS;
}
return true;
}
return false;
}
// Base case: Expression is an induction variable
else if (mlir::scf::ForOp forOp = scf::getForInductionVarOwner(expr)) {
if (owningForOp != nullptr)
*owningForOp = forOp;
return true;
}
return false;
}
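// For example, %i * 4 + 1 and (%i - 2) / 3 are quasi-affine
// expressions on the IV %i, whereas %i * %j (an IV on both sides of a
// multiplication) and %c / %i (an IV in the divisor) are rejected.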
/// Invokes `callback` for every subexpression of `expr` that is an
/// induction variable with the corresponding for operation as the
/// argument. Stops if the callback function returns `true`.
static bool forEveryReferencedInductionVarBreakable(
    mlir::Value expr, llvm::function_ref<bool(mlir::scf::ForOp)> callback) {
  if (mlir::scf::ForOp forOp = scf::getForInductionVarOwner(expr)) {
    return callback(forOp);
  } else if (expr.getDefiningOp()) {
    for (mlir::Value operand : expr.getDefiningOp()->getOperands()) {
      // Propagate the stop request from subexpressions so that the
      // traversal actually stops once the callback returns `true`
      if (forEveryReferencedInductionVarBreakable(operand, callback))
        return true;
    }
  }
  return false;
}
/// Invokes `callback` for every subexpression of `expr` that is an
/// induction variable with the corresponding for operation as the
/// argument.
static void forEveryReferencedInductionVar(
mlir::Value expr, llvm::function_ref<void(mlir::scf::ForOp)> callback) {
forEveryReferencedInductionVarBreakable(expr,
[&](mlir::scf::ForOp forOp) -> bool {
callback(forOp);
return false;
});
}
/// Returns the loop associated to the first induction variable
/// encountered in a subexpression of `expr`.
static mlir::scf::ForOp findFirstReferencedInductionVar(mlir::Value expr) {
mlir::scf::ForOp ret;
forEveryReferencedInductionVarBreakable(expr,
[&](mlir::scf::ForOp forOp) -> bool {
ret = forOp;
return true;
});
return ret;
}
/// Checks if `expr` is a quasi-affine expression on a single induction
/// variable for which incrementing the induction variable by the step
/// of the associated for loop results in a constant increment of the
/// expression's value.
///
/// E.g., this is true for the expression `i+1` for any constant step
/// size, since `((i+step)+1) - (i+1)` is constant. This is also true
/// for `(i+5)/7` for a step size that is a multiple of `7`, but false
/// for any other step size.
static bool
isQuasiAffineIVExpressionWithConstantStep(mlir::Value expr,
mlir::scf::ForOp *forOp = nullptr) {
mlir::scf::ForOp tmpForOp;
if (isQuasiAffineIVExpression(expr, &tmpForOp)) {
std::optional<BoundsAndStep> bas =
getBoundsOfQuasiAffineIVExpression(expr, tmpForOp);
if (bas.has_value()) {
if (forOp != nullptr)
*forOp = tmpForOp;
return true;
}
}
return false;
}
/// Hoists the pure operation producing the value `v` out of
/// `outermostFor` recursively. All newly created mappings are
/// collected in `mapping`.
static mlir::Value hoistPure(mlir::PatternRewriter &rewriter,
mlir::scf::ForOp outermostFor,
mlir::IRMapping &mapping, mlir::Value v) {
if (outermostFor.isDefinedOutsideOfLoop(v))
return v;
mlir::Operation *op = v.getDefiningOp();
assert(op && mlir::isPure(op) && op->getNumResults() == 1);
for (mlir::Value operand : op->getOperands()) {
if (!mapping.contains(operand))
mapping.map(operand, hoistPure(rewriter, outermostFor, mapping, operand));
}
rewriter.setInsertionPoint(outermostFor);
mlir::Operation *clonedOp = rewriter.clone(*op, mapping);
return clonedOp->getResult(0);
}
/// Hoists the pure operation producing the value `v` out of
/// `outermostFor` recursively.
static mlir::Value hoistPure(mlir::PatternRewriter &rewriter,
mlir::scf::ForOp outermostFor, mlir::Value v) {
mlir::IRMapping mapping;
return hoistPure(rewriter, outermostFor, mapping, v);
}
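// Illustrative sketch (hypothetical IR): if %v is computed inside the
// loop nest as
//
//   %c1 = arith.constant 1 : i64
//   %v  = arith.addi %x, %c1 : i64
//
// with %x defined above `outermostFor`, both operations are cloned
// immediately before `outermostFor` and the result of the cloned
// arith.addi is returned.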
/// Hoists an operation embedded into a loop nest that is indexed using
/// quasi-affine expressions of the loops' IVs out of the outermost
/// loop, using `tensorOpBuilder` to create the hoisted tensor
/// operation (e.g., a `tensor.extract_slice`)
template <typename EltWiseOp>
mlir::Value hoistIndexedOp(
mlir::PatternRewriter &rewriter, mlir::scf::ForOp outermostFor,
mlir::Value tensorizedOperands, EltWiseOp eltwiseOp,
llvm::function_ref<mlir::Value(
mlir::ImplicitLocOpBuilder &, mlir::Value,
llvm::ArrayRef<mlir::OpFoldResult>, llvm::ArrayRef<mlir::OpFoldResult>,
llvm::ArrayRef<mlir::OpFoldResult>, llvm::ArrayRef<bool>)>
tensorOpBuilder) {
llvm::SmallVector<mlir::OpFoldResult> offsets;
llvm::SmallVector<mlir::OpFoldResult> sizes;
llvm::SmallVector<mlir::OpFoldResult> strides;
llvm::SmallVector<bool> ivIndexedDims;
rewriter.setInsertionPoint(outermostFor);
mlir::ImplicitLocOpBuilder ilob(eltwiseOp.getLoc(), rewriter);
for (mlir::Value idx : eltwiseOp.getIndices()) {
mlir::scf::ForOp forOp;
bool isAffine = isQuasiAffineIVExpression(idx, &forOp);
if (isAffine && forOp) {
std::optional<BoundsAndStep> bas =
getBoundsOfQuasiAffineIVExpression(idx, forOp);
assert(bas.has_value());
assert(bas->step != 0);
offsets.push_back(rewriter.getIndexAttr(bas->lb));
sizes.push_back(rewriter.getIndexAttr((bas->ub - bas->lb) / bas->step));
strides.push_back(rewriter.getIndexAttr(bas->step));
ivIndexedDims.push_back(true);
} else if (isAffine || outermostFor.isDefinedOutsideOfLoop(idx)) {
offsets.push_back(getValueAsOpFoldResult(idx));
sizes.push_back(rewriter.getIndexAttr(1));
strides.push_back(rewriter.getIndexAttr(1));
ivIndexedDims.push_back(false);
}
}
return tensorOpBuilder(ilob, tensorizedOperands, offsets, sizes, strides,
ivIndexedDims);
}
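// Worked example: for an access %T[%c5, %i] inside
// `scf.for %i = %c0 to %c10 step %c2`, the loop above produces
// offsets = [5, 0], sizes = [1, 5], strides = [1, 2] and
// ivIndexedDims = [false, true].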
/// Hoists a tensor.extract operation embedded into a loop nest as a
/// `tensor.extract_slice` outside of the outermost loop of the nest
static mlir::Value hoistExtractOp(mlir::PatternRewriter &rewriter,
mlir::scf::ForOp outermostFor,
mlir::tensor::ExtractOp extractOp) {
return hoistIndexedOp<mlir::tensor::ExtractOp>(
rewriter, outermostFor, extractOp.getTensor(), extractOp,
[](mlir::ImplicitLocOpBuilder &builder, mlir::Value tensorizedOperands,
llvm::ArrayRef<mlir::OpFoldResult> offsets,
llvm::ArrayRef<mlir::OpFoldResult> sizes,
llvm::ArrayRef<mlir::OpFoldResult> strides,
llvm::ArrayRef<bool> ivIndexedDims) -> mlir::Value {
mlir::tensor::ExtractSliceOp slice =
builder.create<mlir::tensor::ExtractSliceOp>(
tensorizedOperands, offsets, sizes, strides);
// The extract slice operation above preserves non-IV-indexed
// dimensions of the original extract operation as 1-sized
// dimensions, e.g., a `tensor.extract[cst, i, j, cst, k]`
// results in a slice with the shape `1xMxNx1xK` (where M, N
// and K are the trip counts of the loops over i, j and k,
// assuming a loop step of 1).
//
// If there is any non-IV-indexed dimension, add a collapse
// shape operation that collapses the 1-sized dimensions
// originating from non-IV-indexed dimensions of the extract
// operation into the IV-indexed dimensions. I.e., in the
// above example, produce a slice with the shape `MxNxK`
// rather than `1xMxNx1xK`.
if (llvm::all_of(ivIndexedDims, [](bool v) { return v; })) {
return slice;
} else {
llvm::SmallVector<mlir::ReassociationIndices> collapseGroups;
mlir::ReassociationIndices currCollapseGroup;
bool prefixDone = false;
for (auto i : llvm::enumerate(ivIndexedDims)) {
// If this is a non-IV-indexed dimension, accumulate
// the dimension in the current group of collapsed dimensions
if (!i.value()) {
currCollapseGroup.push_back(i.index());
} else {
// If there were only non-IV-indexed dimensions before,
// add this first IV-indexed dimension to the current
// group of collapsed dimensions and try to accumulate
// with following, non-IV-indexed dimensions.
if (!prefixDone) {
currCollapseGroup.push_back(i.index());
prefixDone = true;
} else {
collapseGroups.push_back(currCollapseGroup);
currCollapseGroup = mlir::ReassociationIndices();
currCollapseGroup.push_back(i.index());
}
}
}
// Add last collapse group for trailing series of
// non-IV-indexed dimensions
if (!currCollapseGroup.empty())
collapseGroups.push_back(currCollapseGroup);
mlir::tensor::CollapseShapeOp cso =
builder.create<mlir::tensor::CollapseShapeOp>(slice,
collapseGroups);
return cso;
}
});
}
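// For instance, with ivIndexedDims = [false, true, true, false, true]
// (i.e., a slice of shape 1xMxNx1xK), the collapse groups computed
// above are [[0, 1], [2, 3], [4]], yielding a tensor of shape MxNxK.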
/// Hoists a tensor.insert operation embedded into a loop nest as a
/// tensor.insert_slice outside of the outermost loop of the nest
static mlir::Value hoistInsertOp(mlir::PatternRewriter &rewriter,
mlir::Value tensorizedOperands,
mlir::Value targetTensor,
mlir::scf::ForOp outermostFor,
mlir::tensor::InsertOp insertOp) {
return hoistIndexedOp<mlir::tensor::InsertOp>(
rewriter, outermostFor, targetTensor, insertOp,
[&](mlir::ImplicitLocOpBuilder &builder, mlir::Value targetTensor,
llvm::ArrayRef<mlir::OpFoldResult> offsets,
llvm::ArrayRef<mlir::OpFoldResult> sizes,
llvm::ArrayRef<mlir::OpFoldResult> strides,
llvm::ArrayRef<bool> ivIndexedDims) -> mlir::Value {
return builder.create<mlir::tensor::InsertSliceOp>(
tensorizedOperands, targetTensor, offsets, sizes, strides);
});
}
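// Illustrative sketch (hypothetical IR): a
// `tensor.insert %s into %U[%i, %j]` nested in two normalized loops
// with trip counts N and M becomes, above the loop nest, roughly
//
//   %res = tensor.insert_slice %batched into %U[0, 0] [N, M] [1, 1]
//
// where %batched holds the batched values to be inserted.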
/// Pattern that replaces a batchable operation embedded into a loop
/// nest with the batched version of the operation, e.g.,
///
/// scf.for %i = %c0 to %cN step %c1 {
/// scf.for %j = %c0 to %cM step %c1 {
/// scf.for %k = %c0 to %cK step %c1 {
/// %s = tensor.extract %T[%i, %j, %k]
/// %res = batchable_op %s
/// ...
/// }
/// }
/// }
///
/// is replaced with:
///
/// %batchedSlice = tensor.extract_slice
/// %T[%c0, %c0, %c0] [%cN, %cM, %cK] [%c1, %c1, %c1]
/// %flatSlice = tensor.collapse_shape %batchedSlice
/// %resTFlat = batchedOp %flatSlice
/// %resT = tensor.expand_shape %resTFlat
///
/// scf.for %i = %c0 to %cN step %c1 {
/// scf.for %j = %c0 to %cM step %c1 {
/// scf.for %k = %c0 to %cK step %c1 {
/// %res = tensor.extract %resT[%i, %j, %k]
/// ...
/// }
/// }
/// }
///
/// Any index may be a quasi-affine expression on a single loop
/// induction variable, but the distance between the result for any
/// two successive values of the IV must be constant.
class BatchingPattern : public mlir::OpRewritePattern<mlir::func::FuncOp> {
public:
BatchingPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<mlir::func::FuncOp>(context) {}
mlir::LogicalResult
matchAndRewrite(mlir::func::FuncOp func,
mlir::PatternRewriter &rewriter) const override {
// Operation that will be hoisted out of the loop nest and
// replaced by the batched version of the operation
BatchableOpInterface targetOp;
// Extract operation producing the scalar operand of the batchable
// operation
mlir::tensor::ExtractOp targetExtractOp;
// Outermost for loop of the loop nest in which the batchable op
// is located
mlir::scf::ForOp outermostFor;
// Find a batchable op which is embedded into a loop nest
func.walk([&](BatchableOpInterface scalarOp) {
// Is producer an extract op?
auto extractOp = llvm::dyn_cast_or_null<mlir::tensor::ExtractOp>(
scalarOp.getBatchableOperand().get().getDefiningOp());
if (!extractOp)
return mlir::WalkResult::skip();
// Is extract op embedded into a loop?
if (!isa<mlir::scf::ForOp>(extractOp->getParentOp()))
return mlir::WalkResult::skip();
// Find outermost for loop whose IV is used as an index
mlir::scf::ForOp currOutermostFor;
for (mlir::Value idx : extractOp.getIndices()) {
forEveryReferencedInductionVar(idx, [&](mlir::scf::ForOp forOp) {
if (!currOutermostFor ||
forOp.getOperation()->isAncestor(currOutermostFor)) {
currOutermostFor = forOp;
}
});
}
if (!currOutermostFor)
return mlir::WalkResult::skip();
if (!currOutermostFor.isDefinedOutsideOfLoop(extractOp.getTensor()))
return mlir::WalkResult::skip();
// Now make sure that each index is either a quasi-affine
// expression on a single loop IV with a constant increment per
// step, a constant, or a value defined above the outermost loop.
for (mlir::Value idx : extractOp.getIndices()) {
if (!currOutermostFor.isDefinedOutsideOfLoop(idx) &&
!(idx.getDefiningOp() &&
isa<mlir::arith::ConstantOp>(idx.getDefiningOp())) &&
!isQuasiAffineIVExpressionWithConstantStep(idx)) {
return mlir::WalkResult::skip();
}
}
// Verify that the other operands are defined outside the loop
// nest or are hoistable
if (!llvm::all_of(scalarOp.getNonBatchableOperands(), [&](mlir::Value v) {
return isHoistable(v, currOutermostFor);
})) {
return mlir::WalkResult::skip();
}
// Make sure that there are only loops on the way from the
// outermost loop to the extract operation (i.e., loops are not
// embedded in other regions)
for (Operation *op = extractOp->getParentOp();
op != currOutermostFor->getParentOp(); op = op->getParentOp()) {
if (!isa<mlir::scf::ForOp>(op) ||
!isStaticLoop(llvm::dyn_cast<mlir::scf::ForOp>(op)))
return mlir::WalkResult::skip();
}
targetOp = scalarOp;
outermostFor = currOutermostFor;
targetExtractOp = extractOp;
return mlir::WalkResult::interrupt();
});
if (!targetOp)
return mlir::failure();
mlir::Value slice = hoistExtractOp(rewriter, outermostFor, targetExtractOp);
mlir::RankedTensorType sliceType =
slice.getType().cast<mlir::RankedTensorType>();
mlir::Value flattenedSlice;
mlir::ReassociationIndices indices;
if (sliceType.getRank() == 1) {
flattenedSlice = slice;
} else {
// Flatten the tensor with the batched operands, so that they
// can be passed as a one-dimensional tensor to the batched
// operation
for (int64_t i = 0; i < sliceType.getRank(); i++)
indices.push_back(i);
flattenedSlice = rewriter.create<mlir::tensor::CollapseShapeOp>(
targetExtractOp.getLoc(), slice,
llvm::SmallVector<mlir::ReassociationIndices>{indices});
}
// Hoist all non-batchable operands
llvm::SmallVector<mlir::Value> hoistedNonBatchableOperands;
for (mlir::Value operand : targetOp.getNonBatchableOperands()) {
hoistedNonBatchableOperands.push_back(
hoistPure(rewriter, outermostFor, operand));
}
// Create the batched operation and pass flattened, batched
// operands
mlir::ImplicitLocOpBuilder ilob(targetExtractOp.getLoc(), rewriter);
mlir::Value batchedOpResult = targetOp.createBatchedOperation(
ilob, flattenedSlice, hoistedNonBatchableOperands);
mlir::Value expandedBatchResultTensor;
if (sliceType.getRank() == 1) {
expandedBatchResultTensor = batchedOpResult;
} else {
// Restore the original shape of the batched operands for the
// result of the batched operation. Dimensions resulting from
// indexing with non-loop IVs remain collapsed.
mlir::Type expandedBatchResultType = mlir::RankedTensorType::get(
sliceType.getShape(), batchedOpResult.getType()
.dyn_cast<mlir::RankedTensorType>()
.getElementType());
expandedBatchResultTensor = rewriter.create<mlir::tensor::ExpandShapeOp>(
targetExtractOp.getLoc(), expandedBatchResultType, batchedOpResult,
llvm::SmallVector<mlir::ReassociationIndices>{indices});
}
// Collect all loop IVs from the extract op. These will be used to
// index the batched result tensor within the loop for consumers
// of the batchable op
llvm::SmallVector<mlir::Value> shiftedLoopIVs;
ilob.setInsertionPoint(targetOp);
for (mlir::Value idx : targetExtractOp.getIndices()) {
mlir::scf::ForOp forOp = findFirstReferencedInductionVar(idx);
if (forOp) {
// If the loop has a lower bound other than 0 or a non-unit
// step, shift the index to match the shape of the batched
// results, i.e., use (IV - lb) / step.
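// E.g., for `scf.for %i = %c4 to %c20 step %c2`, the batched
// result is indexed with (%i - 4) / 2.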
if (!isStaticNormalizedLoop(forOp)) {
idx = getOpFoldResultAsValue(
ilob,
opFoldDiv(
ilob,
opFoldSub(ilob,
getValueAsOpFoldResult(forOp.getInductionVar()),
getValueAsOpFoldResult(forOp.getLowerBound())),
getValueAsOpFoldResult(forOp.getStep())));
}
shiftedLoopIVs.push_back(idx);
}
}
rewriter.setInsertionPoint(targetOp);
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractOp>(
targetOp, expandedBatchResultTensor, shiftedLoopIVs);
return mlir::success();
}
};
/// Cleanup pattern that replaces a perfect loop nest resulting from
/// repeated application of `BatchingPattern`, whose innermost loop
/// contains only a `tensor.extract`, a `tensor.insert` and an
/// `scf.yield` op (besides side-effect-free operations), with
/// `tensor.extract_slice` and `tensor.insert_slice` ops. E.g.,
///
/// %res0 = scf.for %i = %c0 to %cN step %c1 iter_args(%arg0 = %T1) {
/// %res1 = scf.for %j = %c0 to %cM step %c1 iter_args(%arg1 = %arg0) {
/// %res2 = scf.for %k = %c0 to %cK step %c1 iter_args(%arg2 = %arg1) {
/// %s = tensor.extract %T2[%i, %j, %k]
/// %TRes = tensor.insert %s into %arg2
/// scf.yield %arg2
/// }
/// scf.yield %res2
/// }
/// scf.yield %res1
/// }
///
/// is replaced with:
///
/// %tmp = tensor.extract_slice %T2
/// %res0 = tensor.insert_slice %tmp into %T1
///
/// Any index may be a quasi-affine expression on a single loop
/// induction variable, but the distance between the result for any
/// two successive values of the IV must be constant.
class CleanupPattern : public mlir::OpRewritePattern<mlir::func::FuncOp> {
public:
CleanupPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<mlir::func::FuncOp>(context) {}
mlir::LogicalResult
matchAndRewrite(mlir::func::FuncOp func,
mlir::PatternRewriter &rewriter) const override {
mlir::scf::ForOp outermostFor;
mlir::tensor::ExtractOp extractOp;
mlir::tensor::InsertOp insertOp;
func.walk([&](mlir::tensor::ExtractOp currExtractOp) {
// First check that the extract op is embedded in a for loop
mlir::scf::ForOp innermostFor =
llvm::dyn_cast<mlir::scf::ForOp>(currExtractOp->getParentOp());
if (!innermostFor)
return mlir::WalkResult::skip();
// Next, find a chain of the three relevant operations:
//
// %s = tensor.extract %T[...]
// %U' = tensor.insert %s into %U[...]
// scf.yield %U'
//
// All other operations must be side-effect-free.
mlir::Block &body = innermostFor.getRegion().front();
mlir::scf::YieldOp yield =
llvm::dyn_cast<mlir::scf::YieldOp>(body.getTerminator());
if (yield.getOperands().size() != 1)
return mlir::WalkResult::skip();
mlir::Operation *yieldOperandProducer =
yield.getOperand(0).getDefiningOp();
if (!yieldOperandProducer)
return mlir::WalkResult::skip();
mlir::tensor::InsertOp currInsertOp =
llvm::dyn_cast<mlir::tensor::InsertOp>(yieldOperandProducer);
if (!currInsertOp ||
currInsertOp.getScalar() != currExtractOp.getResult())
return mlir::WalkResult::skip();
if (!llvm::all_of(body, [&](mlir::Operation &op) {
return isMemoryEffectFree(&op);
})) {
return mlir::WalkResult::skip();
}
// Now check that all IVs used for indexing are from perfectly
// nested loops down to the parent loop of the extract op and
// identify the outermost loop of the nest
mlir::scf::ForOp currOutermostFor;
if (currExtractOp.getIndices().size() !=
currInsertOp.getIndices().size()) {
return mlir::WalkResult::skip();
}
// Find outermost for loop whose IV is used as an index and make
// sure that IVs are used for the same indexes of the extract
// and insert operations
for (auto it :
llvm::zip(currExtractOp.getIndices(), currInsertOp.getIndices())) {
mlir::Value extractIdx = std::get<0>(it);
mlir::Value insertIdx = std::get<1>(it);
mlir::scf::ForOp extractForOp;
mlir::scf::ForOp insertForOp;
if (!isQuasiAffineIVExpressionWithConstantStep(extractIdx,
&extractForOp) ||
!isQuasiAffineIVExpressionWithConstantStep(insertIdx,
&insertForOp)) {
return mlir::WalkResult::skip();
}
if (insertForOp && extractForOp &&
insertForOp.getOperation() == extractForOp.getOperation()) {
if (!currOutermostFor ||
extractForOp.getOperation()->isAncestor(currOutermostFor)) {
currOutermostFor = extractForOp;
}
}
}
if (!currOutermostFor)
return mlir::WalkResult::skip();
// Check that the nest from the outermost to the innermost loop
// is perfect and forwards the result of the innermost loop to
// the outermost one
for (mlir::Operation *forOp = innermostFor.getOperation()->getParentOp();
forOp != currOutermostFor.getOperation()->getParentOp();
forOp = forOp->getParentOp()) {
mlir::scf::ForOp currentFor = llvm::dyn_cast<mlir::scf::ForOp>(forOp);
// Parent is not a for loop
if (!currentFor)
return mlir::WalkResult::skip();
// Body must have exactly two ops: a for loop and a yield
mlir::Block &body = currentFor.getRegion().front();
if (body.begin() != std::prev(body.end(), 2))
return mlir::WalkResult::skip();
mlir::scf::ForOp childFor =
llvm::dyn_cast<mlir::scf::ForOp>(*body.begin());
mlir::scf::YieldOp yield =
llvm::dyn_cast<mlir::scf::YieldOp>(*std::next(body.begin()));
// Check that result of the for loop is forwarded
if (!childFor || !yield || yield.getOperands().size() != 1 ||
childFor.getResults().size() != 1 ||
yield.getOperand(0) != childFor.getResult(0))
return mlir::WalkResult::skip();
}
outermostFor = currOutermostFor;
extractOp = currExtractOp;
insertOp = currInsertOp;
return mlir::WalkResult::interrupt();
});
if (!outermostFor)
return mlir::failure();
// Outermost for loop must produce exactly one result
if (outermostFor.getInitArgs().size() != 1)
return mlir::failure();
// Original tensor that is carried through the loops
mlir::Value initialTensor = outermostFor.getInitArgs().front();
mlir::Value slice = hoistExtractOp(rewriter, outermostFor, extractOp);
mlir::Value insertedSlice =
hoistInsertOp(rewriter, slice, initialTensor, outermostFor, insertOp);
// Replace the entire loop nest with the result of the insert
// slice op. Since this is a perfect loop nest with the innermost
// body only producing the tensor elements, there cannot be any
// other operations that produce results or that have side
// effects.
rewriter.replaceOp(outermostFor, {insertedSlice});
return mlir::success();
}
};
class BatchingPass : public BatchingBase<BatchingPass> {
public:
void runOnOperation() override {
mlir::Operation *op = getOperation();
mlir::RewritePatternSet patterns(op->getContext());
patterns.add<BatchingPattern, CleanupPattern>(op->getContext());
if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed())
this->signalPassFailure();
}
};
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createBatchingPass() {
return std::make_unique<BatchingPass>();
}
} // namespace concretelang
} // namespace mlir