feat(compiler): Batching: Add pattern that folds operations on tensors of constants

This adds a new pattern to the batching pass that folds operations on
tensors of constants into new tensors of constants. E.g.,

  %cst = arith.constant dense<...> : tensor<Nxi9>
  %res = scf.for %i = %c0 to %cN {
    %cst_i9 = tensor.extract %cst[%i]
    %cst_i64 = arith.extui %cst_i9 : i9 to i64
    ...
  }

becomes:

  %cst = arith.constant dense<...> : tensor<Nxi64>
  %res = scf.for %i = %c0 to %cN {
    %cst_i64 = tensor.extract %cst[%i]
    ...
  }

The pattern only applies to static loops, to indexes that are
quasi-affine expressions on a single loop induction variable with a
constant step size across iterations, and to foldable operations
that have a single result.
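
For illustration, a loop like the following (hypothetical example)
would not match, since its upper bound is not static:

  %res = scf.for %i = %c0 to %n {
    %cst_i9 = tensor.extract %cst[%i]
    ...
  }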
Author: Andi Drebes
Date: 2023-06-01 12:14:59 +02:00
Committed by: Antoniu Pop
parent 3516ae7682
commit 38e14446d6

@@ -6,6 +6,7 @@
#include <functional>
#include <limits>
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/TypeSwitch.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
@@ -215,6 +216,16 @@ sortLoopsInnermostToOutermost(ContainerTy &forOps) {
return res;
}
// Sorts the `scf.for` loops from `forOps` from the outermost to the
// innermost loop. The loops must be nested one within another.
template <typename ContainerTy>
SmallVector<mlir::scf::ForOp>
sortLoopsOutermostToInnermost(ContainerTy &forOps) {
SmallVector<mlir::scf::ForOp> nest = sortLoopsInnermostToOutermost(forOps);
std::reverse(nest.begin(), nest.end());
return nest;
}
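// Illustrative use (assumed, not part of the original change): for a
// nest `scf.for %i { scf.for %j { ... } }`, passing {forJ, forI} in
// any order yields {forI, forJ}, i.e., the outermost loop first.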
// Takes a set of loops `forOps` and finds the longest sequence of
// perfectly nested loops starting with innermost loops of `forOps`,
// in which the predicate `parentChildPredicate` holds for all loops
@@ -1988,6 +1999,378 @@ public:
}
};
// Folds the operation `op` by recursively folding all
// producers. Occurrences of `arg` are replaced with `argVal`.
// All encountered operations must produce a single result.
static mlir::Attribute fold(mlir::Operation *op, mlir::Value arg,
mlir::Attribute argVal) {
assert(op->getNumResults() == 1);
// Check if `arg` needs to be replaced with `argVal`
if (op->getResult(0) == arg)
return argVal;
// Constants are just folded to their value attributes
if (mlir::arith::ConstantOp cstOp =
llvm::dyn_cast<mlir::arith::ConstantOp>(op)) {
return cstOp.getValue();
}
// Recursively fold all producers and collect the folding results
// for each operand
llvm::SmallVector<mlir::Attribute> foldedOperands;
for (mlir::OpOperand &operand : op->getOpOperands()) {
mlir::Operation *producer = operand.get().getDefiningOp();
assert(producer);
foldedOperands.push_back(fold(producer, arg, argVal));
}
// Invoke the folder for this operation
llvm::SmallVector<mlir::OpFoldResult> res;
mlir::LogicalResult foldRes = op->fold(foldedOperands, res);
assert(foldRes.succeeded());
assert(res.size() == 1);
mlir::Attribute resAttr = res[0].dyn_cast<mlir::Attribute>();
assert(resAttr);
return resAttr;
}
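// Illustrative example (assumed): for the chain
//   %e = tensor.extract %cst[%i] ; %r = arith.extsi %e : i9 to i64
// calling fold(extsiOp, %e, <i9 attribute for 5>) substitutes the
// attribute for the extract result and returns the sign-extended
// i64 attribute for 5.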
// Folding pattern that collapses operations on constant dense tensors
// into a new constant. E.g.,
//
// %cst = arith.constant dense<...> : tensor<Nxi9>
// %res = scf.for %i = %c0 to %cN {
// %cst_i9 = tensor.extract %cst[%i]
// %cst_i64 = arith.extui %cst_i9 : i9 to i64
// ...
// }
//
// becomes:
//
// %cst = arith.constant dense<...> : tensor<Nxi64>
// %res = scf.for %i = %c0 to %cN {
// %cst_i64 = tensor.extract %cst[%i]
// ...
// }
class ConstantDenseFoldingPattern
: public mlir::OpRewritePattern<mlir::func::FuncOp> {
protected:
// Checks if an operation is foldable
static bool isFoldableOp(mlir::Operation *op) {
return op->getNumResults() == 1 && mlir::isPure(op) &&
llvm::TypeSwitch<mlir::Operation *, bool>(op)
.Case<mlir::arith::AddIOp, mlir::arith::ExtSIOp,
mlir::arith::ConstantOp>([](auto op) { return true; })
.Default([](auto op) { return false; });
}
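// E.g., `%a = arith.addi %x, %c1 : i64` is foldable (pure, single
// result, whitelisted above), whereas a multi-result or
// side-effecting operation is not (illustrative).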
// Checks if `v` can be calculated statically given that the values
// in `foldables` are static. The function recursively collects all
// intermediate values which have been found to be static in
// `foldables`.
static bool isFoldableValue(mlir::Value v,
llvm::DenseSet<mlir::Value> &foldables) {
if (foldables.contains(v))
return true;
mlir::Operation *op = v.getDefiningOp();
if (!op || !isFoldableOp(op))
return false;
if (llvm::all_of(op->getOperands(), [&](mlir::Value v) {
return isFoldableValue(v, foldables);
})) {
for (mlir::Value v : op->getOperands())
foldables.insert(v);
return true;
}
return false;
}
// Generates a flat index into a tensor with the shape `shape`
// indexed by `idx`
static int64_t linearizeIndex(llvm::ArrayRef<int64_t> shape,
llvm::ArrayRef<int64_t> idx) {
int64_t flatIdx = 0;
int64_t mul = 1;
int64_t n = shape.size();
for (int64_t i = 0; i < n; i++) {
flatIdx += mul * idx[n - i - 1];
mul *= shape[n - i - 1];
}
return flatIdx;
}
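// E.g., for a row-major shape {4, 3} and index {2, 1}, the flat
// index is 2 * 3 + 1 = 7 (illustrative).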
public:
ConstantDenseFoldingPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<mlir::func::FuncOp>(context) {}
mlir::LogicalResult
matchAndRewrite(mlir::func::FuncOp func,
mlir::PatternRewriter &rewriter) const override {
mlir::tensor::ExtractOp extractOp;
mlir::Operation *targetOp = nullptr;
llvm::SmallVector<mlir::scf::ForOp> nest;
llvm::SmallVector<mlir::scf::ForOp> idxMap;
llvm::SmallVector<BoundsAndStep> idxBounds;
mlir::arith::ConstantOp cdo;
mlir::RankedTensorType constantType;
mlir::Type origElementType;
mlir::Type foldedElementType;
func.walk([&](mlir::tensor::ExtractOp currExtractOp) {
// Check that the extraction is on a value produced by an
// `arith.constant` operation with a dense value.
mlir::arith::ConstantOp currCdo =
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
currExtractOp.getTensor().getDefiningOp());
if (!currCdo)
return mlir::WalkResult::skip();
if (!isSoleUser(currCdo.getResult(), currExtractOp))
return mlir::WalkResult::skip();
mlir::RankedTensorType currConstantType =
currExtractOp.getTensor()
.getType()
.dyn_cast<mlir::RankedTensorType>();
if (!currConstantType)
return mlir::WalkResult::skip();
// First check that the extract op is embedded in a for loop
mlir::scf::ForOp currInnermostFor =
llvm::dyn_cast<mlir::scf::ForOp>(currExtractOp->getParentOp());
if (!currInnermostFor)
return mlir::WalkResult::skip();
llvm::DenseSet<mlir::scf::ForOp> nestUnsorted;
llvm::SmallVector<mlir::scf::ForOp> currIdxMap;
llvm::SmallVector<BoundsAndStep> currIdxBounds;
// Make sure that the extract operation uses only quasi-affine
// expressions on IVs, where each index uses at most a single
// IV.
for (mlir::Value idx : currExtractOp.getIndices()) {
mlir::scf::ForOp forOp;
BoundsAndStep bas;
if (!isQuasiAffineIVExpressionWithConstantStep(idx, &forOp, &bas))
return mlir::WalkResult::skip();
if (forOp)
nestUnsorted.insert(forOp);
currIdxBounds.push_back(bas);
currIdxMap.push_back(forOp);
}
llvm::DenseSet<mlir::Value> foldables;
foldables.insert(currExtractOp.getResult());
if (!currExtractOp.getResult().hasOneUse())
return mlir::WalkResult::skip();
mlir::Operation *firstUser =
currExtractOp.getResult().getUses().begin()->getOwner();
mlir::Operation *currOp = firstUser;
mlir::Operation *currTargetOp = nullptr;
mlir::Type currOrigElementType = currConstantType.getElementType();
mlir::Type currFoldedElementType = currOrigElementType;
// Walk down the def-use chain from the extract operation until
// an operation is found that is not foldable
while (true) {
if (!isFoldableOp(currOp))
break;
if (currOp->getNumResults() != 1 || !currOp->getResult(0).hasOneUse() ||
currOp->getParentOp() != currInnermostFor.getOperation())
break;
if (!llvm::all_of(currOp->getOperands(), [&](mlir::Value v) {
return isFoldableValue(v, foldables);
}))
break;
currFoldedElementType = currOp->getResult(0).getType();
currTargetOp = currOp;
currOp = currOp->getUses().begin()->getOwner();
}
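// At this point, `currTargetOp` (if set) is the last foldable
// operation in the chain, e.g. a final `arith.extsi` on the
// extracted value (illustrative).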
if (!currTargetOp)
return mlir::WalkResult::skip();
// Check constraints on the index space of the extract
// operation. I.e., if the type changes during the folding,
// ensure that the index space covers the entire tensor and that
// there are no out-of-bounds accesses.
for (auto it : llvm::enumerate(currExtractOp.getIndices())) {
mlir::scf::ForOp forOp;
BoundsAndStep bas;
mlir::Value idx = it.value();
size_t i = it.index();
if (!isQuasiAffineIVExpressionWithConstantStep(idx, &forOp, &bas))
return mlir::WalkResult::skip();
// If the type is changed by the folding, the entire tensor needs
// to be rewritten
if (currFoldedElementType != currOrigElementType) {
if (bas.lb != 0 || bas.ub != currConstantType.getDimSize(i) ||
bas.step != 1)
return mlir::WalkResult::skip();
}
// Otherwise, just make sure that there are no out-of-bounds
// accesses
else {
if (bas.ub - bas.step >= currConstantType.getDimSize(i))
return mlir::WalkResult::skip();
}
}
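// E.g., for a dimension of size 8 with lb = 0, ub = 8 and step = 1,
// the folding may change the element type, since the index space
// covers the dimension exactly; the last access ub - step = 7 is
// in bounds (illustrative).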
extractOp = currExtractOp;
targetOp = currTargetOp;
nest = sortLoopsOutermostToInnermost(nestUnsorted);
idxMap = std::move(currIdxMap);
idxBounds = std::move(currIdxBounds);
cdo = currCdo;
constantType = currConstantType;
origElementType = currOrigElementType;
foldedElementType = currFoldedElementType;
return mlir::WalkResult::interrupt();
});
if (!targetOp)
return mlir::failure();
// Original tensor of constants
auto denseVals = cdo.getValueAttr()
.cast<mlir::DenseElementsAttr>()
.getValues<mlir::Attribute>();
// Updated tensor of constants initialized with original values
SmallVector<mlir::Attribute> newDenseVals(denseVals.begin(),
denseVals.end());
mlir::SmallVector<int64_t> tripCounts = map(
nest, [](mlir::scf::ForOp forOp) { return getStaticTripCount(forOp); });
// Number of iterations already performed for each loop
mlir::SmallVector<int64_t> trips(nest.size(), 0);
// current index
mlir::SmallVector<int64_t> idx =
map(idxBounds, [](BoundsAndStep &bas) { return bas.lb; });
// Maps the index of each IV in the loop nest to the indexes of
// the extract operation
mlir::SmallVector<mlir::SmallVector<size_t>> revIdxMap(nest.size());
for (size_t i = 0; i < idxMap.size(); i++) {
for (size_t j = 0; j < nest.size(); j++) {
if (nest[j] == idxMap[i]) {
revIdxMap[j].push_back(i);
break;
}
}
}
size_t i = nest.size() - 1;
// Resets the trip count of a loop back to zero and reinitializes
// all indexes using the associated IV
auto resetTrips = [&](size_t loopIdx) {
trips[loopIdx] = 0;
for (size_t i : revIdxMap[loopIdx]) {
idx[i] = idxBounds[i].lb;
}
};
// Increases the trip count of a loop by one and calculates the
// next value of all indexes using the associated IV
auto incTrips = [&](size_t loopIdx) {
trips[loopIdx] += 1;
for (size_t i : revIdxMap[loopIdx]) {
idx[i] += idxBounds[i].step;
}
};
// Iterate over the entire iteration space of the loop nest. The
// variable i represents the index of the loop that is currently
// stepped in the nest
while (true) {
// Loop has reached its maximum trip count. If the loop is the
// first in the nest, the entire space has been
// covered. Otherwise, reset the trip count of the current loop
// and step the loop above.
if (trips[i] == tripCounts[i]) {
if (i == 0)
break;
resetTrips(i);
i--;
incTrips(i);
} else {
// Trip count of the current loop hasn't been reached. If this
// is the innermost loop, calculate a new index, fold all
// values and write the result to the new tensor of
// constants. Otherwise, switch to the next loop in the nest.
if (i == nest.size() - 1) {
size_t flatIdx = linearizeIndex(constantType.getShape(), idx);
mlir::Attribute newVal =
fold(targetOp, extractOp.getResult(), denseVals[flatIdx]);
newDenseVals[flatIdx] = newVal;
incTrips(i);
} else {
i++;
}
}
}
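// Illustrative trace (assumed): for tripCounts = {2, 2}, the loop
// above folds the values at indexes (0,0), (0,1), (1,0) and (1,1),
// with the innermost loop stepped fastest.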
// Create a new `arith.constant` operation with the updated tensor
// of constants
mlir::DenseElementsAttr newDenseElementsAttr = mlir::DenseElementsAttr::get(
mlir::RankedTensorType::get(constantType.getShape(), foldedElementType),
newDenseVals);
rewriter.setInsertionPoint(cdo);
mlir::arith::ConstantOp newCdo = rewriter.create<mlir::arith::ConstantOp>(
cdo.getLoc(), newDenseElementsAttr);
rewriter.setInsertionPoint(targetOp);
// Replace the last op in the chain of foldable operations with a
// `tensor.extract` op on the new tensor of constants.
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractOp>(
targetOp, targetOp->getResult(0).getType(), newCdo,
extractOp.getIndices());
return mlir::success();
}
};
class TensorAllocationCleanupPattern
: public mlir::OpRewritePattern<mlir::func::FuncOp> {
public:
@@ -2023,7 +2406,8 @@ public:
.add<CleanupPattern<mlir::tensor::ExtractOp, mlir::tensor::InsertOp>,
CleanupPattern<mlir::tensor::ExtractSliceOp,
mlir::tensor::InsertSliceOp>,
-TensorAllocationCleanupPattern>(op->getContext());
+ConstantDenseFoldingPattern, TensorAllocationCleanupPattern>(
+op->getContext());
if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed())
this->signalPassFailure();