mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): Batching: Add pattern folding operations on tensors of constants
This adds a new pattern to the batching pass that folds operations on
tensors of constants into new tensors of constants. E.g.,
%cst = arith.constant dense<...> : tensor<Nxi9>
%res = scf.for %i = %c0 to %cN {
%cst_i9 = tensor.extract %cst[%i]
%cst_i64 = arith.extui %cst_i9 : i64
...
}
becomes:
%cst = arith.constant dense<...> : tensor<Nxi64>
%res = scf.for %i = %c0 to %cN {
%cst_i64 = tensor.extract %cst[%i]
...
}
The pattern only works for static loops, indexes that are quasi-affine
expressions on single loop induction variables with a constant step
size across iterations and foldable operations that have a single
result.
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <llvm/ADT/STLExtras.h>
|
||||
#include <llvm/ADT/TypeSwitch.h>
|
||||
#include <mlir/Dialect/Affine/IR/AffineOps.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
@@ -215,6 +216,16 @@ sortLoopsInnermostToOutermost(ContainerTy &forOps) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// Sorts the `scf.for` loops from `forOps` from the outermost to the
|
||||
// innermost loop. The loops must be embedded one into another.
|
||||
template <typename ContainerTy>
|
||||
SmallVector<mlir::scf::ForOp>
|
||||
sortLoopsOutermostToInnermost(ContainerTy &forOps) {
|
||||
SmallVector<mlir::scf::ForOp> nest = sortLoopsInnermostToOutermost(forOps);
|
||||
std::reverse(nest.begin(), nest.end());
|
||||
return nest;
|
||||
}
|
||||
|
||||
// Takes a set of loops `forOps` and finds the longest sequence of
|
||||
// perfectly nested loops starting with innermost loops of `forOps`,
|
||||
// in which the predicate `parentChildPredicate` holds for all loops
|
||||
@@ -1988,6 +1999,378 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Folds the operation `op` by recursively folding all
|
||||
// producers. Occurrences of `arg` are replaced with `argVal`.
|
||||
// All encountered operations must produce a single result.
|
||||
static mlir::Attribute fold(mlir::Operation *op, mlir::Value arg,
|
||||
mlir::Attribute argVal) {
|
||||
assert(op->getNumResults() == 1);
|
||||
|
||||
// Check if `arg` needs to be replaced with `argVal`
|
||||
if (op->getResult(0) == arg)
|
||||
return argVal;
|
||||
|
||||
// Constants are just folded to their value attributes
|
||||
if (mlir::arith::ConstantOp cstOp =
|
||||
llvm::dyn_cast<mlir::arith::ConstantOp>(op)) {
|
||||
return cstOp.getValue();
|
||||
}
|
||||
|
||||
// Recursively fold all producers and collect the folding results
|
||||
// for each operand
|
||||
llvm::SmallVector<mlir::Attribute> foldedOperands;
|
||||
|
||||
for (mlir::OpOperand &operand : op->getOpOperands()) {
|
||||
mlir::Operation *producer = operand.get().getDefiningOp();
|
||||
assert(producer);
|
||||
|
||||
foldedOperands.push_back(fold(producer, arg, argVal));
|
||||
}
|
||||
|
||||
// Invoke the folder for this operation
|
||||
llvm::SmallVector<mlir::OpFoldResult> res;
|
||||
mlir::LogicalResult foldRes = op->fold(foldedOperands, res);
|
||||
|
||||
assert(foldRes.succeeded());
|
||||
assert(res.size() == 1);
|
||||
|
||||
mlir::Attribute resAttr = res[0].dyn_cast<mlir::Attribute>();
|
||||
|
||||
assert(resAttr);
|
||||
|
||||
return resAttr;
|
||||
}
|
||||
|
||||
// Folding pattern that collapses operations on constant dense tensors
|
||||
// into a new constant. E.g.,
|
||||
//
|
||||
// %cst = arith.constant dense<...> : tensor<Nxi9>
|
||||
// %res = scf.for %i = %c0 to %cN {
|
||||
// %cst_i9 = tensor.extract %cst[%i]
|
||||
// %cst_i64 = arith.extui %cst_i9 : i64
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// %cst = arith.constant dense<...> : tensor<Nxi64>
|
||||
// %res = scf.for %i = %c0 to %cN {
|
||||
// %cst_i64 = tensor.extract %cst[%i]
|
||||
// ...
|
||||
// }
|
||||
class ConstantDenseFoldingPattern
|
||||
: public mlir::OpRewritePattern<mlir::func::FuncOp> {
|
||||
protected:
|
||||
// Checks if an operation is foldable
|
||||
static bool isFoldableOp(mlir::Operation *op) {
|
||||
return op->getNumResults() == 1 && mlir::isPure(op) &&
|
||||
llvm::TypeSwitch<mlir::Operation *, bool>(op)
|
||||
.Case<mlir::arith::AddIOp, mlir::arith::ExtSIOp,
|
||||
mlir::arith::ConstantOp>([](auto op) { return true; })
|
||||
.Default([](auto op) { return false; });
|
||||
}
|
||||
|
||||
// Checks if `v` can be calculated statically given that the values
|
||||
// in `foldables` are static. The function recursively collects all
|
||||
// intermediate values which have been found to be static in
|
||||
// `foldables`.
|
||||
static bool isFoldableValue(mlir::Value v,
|
||||
llvm::DenseSet<mlir::Value> &foldables) {
|
||||
if (foldables.contains(v))
|
||||
return true;
|
||||
|
||||
mlir::Operation *op = v.getDefiningOp();
|
||||
|
||||
if (!op || !isFoldableOp(op))
|
||||
return false;
|
||||
|
||||
if (llvm::all_of(op->getOperands(), [&](mlir::Value v) {
|
||||
return isFoldableValue(v, foldables);
|
||||
})) {
|
||||
for (mlir::Value v : op->getOperands())
|
||||
foldables.insert(v);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Generates a flat index from a tensor with the shape `shape`
|
||||
// indexed by `idx`
|
||||
static int64_t linearizeIndex(llvm::ArrayRef<int64_t> shape,
|
||||
llvm::ArrayRef<int64_t> idx) {
|
||||
int64_t flatIdx = 0;
|
||||
int64_t mul = 1;
|
||||
int64_t n = shape.size();
|
||||
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
flatIdx += mul * idx[n - i - 1];
|
||||
mul *= shape[n - i - 1];
|
||||
}
|
||||
|
||||
return flatIdx;
|
||||
};
|
||||
|
||||
public:
|
||||
ConstantDenseFoldingPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<mlir::func::FuncOp>(context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::func::FuncOp func,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::tensor::ExtractOp extractOp;
|
||||
mlir::Operation *targetOp = nullptr;
|
||||
llvm::SmallVector<mlir::scf::ForOp> nest;
|
||||
llvm::SmallVector<mlir::scf::ForOp> idxMap;
|
||||
llvm::SmallVector<BoundsAndStep> idxBounds;
|
||||
mlir::arith::ConstantOp cdo;
|
||||
mlir::RankedTensorType constantType;
|
||||
mlir::Type origElementType;
|
||||
mlir::Type foldedElementType;
|
||||
|
||||
func.walk([&](mlir::tensor::ExtractOp currExtractOp) {
|
||||
// Check that the extraction in on a value produced by an
|
||||
// `arith.constant_dense` operation.
|
||||
mlir::arith::ConstantOp currCdo =
|
||||
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
|
||||
currExtractOp.getTensor().getDefiningOp());
|
||||
|
||||
if (!currCdo)
|
||||
return mlir::WalkResult::skip();
|
||||
|
||||
if (!isSoleUser(currCdo.getResult(), currExtractOp))
|
||||
return mlir::WalkResult::skip();
|
||||
|
||||
mlir::RankedTensorType currConstantType =
|
||||
currExtractOp.getTensor()
|
||||
.getType()
|
||||
.dyn_cast<mlir::RankedTensorType>();
|
||||
|
||||
if (!currConstantType)
|
||||
return mlir::WalkResult::skip();
|
||||
|
||||
// First check that the extract op is embedded in a for loop
|
||||
mlir::scf::ForOp currInnermostFor =
|
||||
llvm::dyn_cast<mlir::scf::ForOp>(currExtractOp->getParentOp());
|
||||
|
||||
if (!currInnermostFor)
|
||||
return mlir::WalkResult::skip();
|
||||
|
||||
llvm::DenseSet<mlir::scf::ForOp> nestUnsorted;
|
||||
llvm::SmallVector<mlir::scf::ForOp> currIdxMap;
|
||||
llvm::SmallVector<BoundsAndStep> currIdxBounds;
|
||||
|
||||
// Make sure that the extract operation uses only quasi affine
|
||||
// expressions on IVs, where each index uses at most a single
|
||||
// IV.
|
||||
for (mlir::Value idx : currExtractOp.getIndices()) {
|
||||
mlir::scf::ForOp forOp;
|
||||
BoundsAndStep bas;
|
||||
|
||||
if (!isQuasiAffineIVExpressionWithConstantStep(idx, &forOp, &bas))
|
||||
return mlir::WalkResult::skip();
|
||||
|
||||
if (forOp)
|
||||
nestUnsorted.insert(forOp);
|
||||
|
||||
currIdxBounds.push_back(bas);
|
||||
currIdxMap.push_back(forOp);
|
||||
}
|
||||
|
||||
llvm::DenseSet<mlir::Value> foldables;
|
||||
foldables.insert(currExtractOp.getResult());
|
||||
|
||||
if (!currExtractOp.getResult().hasOneUse())
|
||||
return mlir::WalkResult::skip();
|
||||
|
||||
mlir::Operation *firstUser =
|
||||
currExtractOp.getResult().getUses().begin()->getOwner();
|
||||
mlir::Operation *currOp = firstUser;
|
||||
mlir::Operation *currTargetOp = nullptr;
|
||||
mlir::Type currOrigElementType = currConstantType.getElementType();
|
||||
mlir::Type currFoldedElementType = currOrigElementType;
|
||||
|
||||
// Walk down the def-use chain from the extract operaion until
|
||||
// an operation is found that is not foldable
|
||||
while (true) {
|
||||
if (!isFoldableOp(currOp))
|
||||
break;
|
||||
|
||||
if (currOp->getNumResults() != 1 || !currOp->getResult(0).hasOneUse() ||
|
||||
currOp->getParentOp() != currInnermostFor.getOperation())
|
||||
break;
|
||||
|
||||
if (!llvm::all_of(currOp->getOperands(), [&](mlir::Value v) {
|
||||
return isFoldableValue(v, foldables);
|
||||
}))
|
||||
break;
|
||||
|
||||
currFoldedElementType = currOp->getResult(0).getType();
|
||||
|
||||
currTargetOp = currOp;
|
||||
currOp = currOp->getUses().begin()->getOwner();
|
||||
}
|
||||
|
||||
if (!currTargetOp)
|
||||
return mlir::WalkResult::skip();
|
||||
|
||||
// Check constraints on the index space of the extract
|
||||
// operation. I.e., if the type changes during teh folding,
|
||||
// ensure that the index space covers the entire tensor and that
|
||||
// there are no out-of-bounds accesses.
|
||||
for (auto it : llvm::enumerate(currExtractOp.getIndices())) {
|
||||
mlir::scf::ForOp forOp;
|
||||
BoundsAndStep bas;
|
||||
mlir::Value idx = it.value();
|
||||
size_t i = it.index();
|
||||
|
||||
if (!isQuasiAffineIVExpressionWithConstantStep(idx, &forOp, &bas))
|
||||
return mlir::WalkResult::skip();
|
||||
|
||||
// If the type changes by the folding, the entire tensor needs
|
||||
// to be rewritten
|
||||
if (currFoldedElementType != currOrigElementType) {
|
||||
if (bas.lb != 0 || bas.ub != currConstantType.getDimSize(i) ||
|
||||
bas.step != 1)
|
||||
return mlir::WalkResult::skip();
|
||||
}
|
||||
// Otherwise, just make sure that there are no out-of-bounds
|
||||
// accesses
|
||||
else {
|
||||
if (bas.ub - bas.step >= currConstantType.getDimSize(i))
|
||||
return mlir::WalkResult::skip();
|
||||
}
|
||||
}
|
||||
|
||||
extractOp = currExtractOp;
|
||||
targetOp = currTargetOp;
|
||||
|
||||
nest = sortLoopsOutermostToInnermost(nestUnsorted);
|
||||
idxMap = std::move(currIdxMap);
|
||||
idxBounds = std::move(currIdxBounds);
|
||||
cdo = currCdo;
|
||||
constantType = currConstantType;
|
||||
origElementType = currOrigElementType;
|
||||
foldedElementType = currFoldedElementType;
|
||||
|
||||
return mlir::WalkResult::interrupt();
|
||||
});
|
||||
|
||||
if (!targetOp)
|
||||
return mlir::failure();
|
||||
|
||||
// Original tensor of constants
|
||||
auto denseVals = cdo.getValueAttr()
|
||||
.cast<mlir::DenseElementsAttr>()
|
||||
.getValues<mlir::Attribute>();
|
||||
|
||||
// Updated tensor of constants intialized with original values
|
||||
SmallVector<mlir::Attribute> newDenseVals(denseVals.begin(),
|
||||
denseVals.end());
|
||||
|
||||
mlir::SmallVector<int64_t> tripCounts = map(
|
||||
nest, [](mlir::scf::ForOp forOp) { return getStaticTripCount(forOp); });
|
||||
|
||||
// Number of iterations already performed for each loop
|
||||
mlir::SmallVector<int64_t> trips(nest.size(), 0);
|
||||
|
||||
// current index
|
||||
mlir::SmallVector<int64_t> idx =
|
||||
map(idxBounds, [](BoundsAndStep &bas) { return bas.lb; });
|
||||
|
||||
// Maps the index of each IV in the loop nest to the indexes of
|
||||
// the extract operation
|
||||
mlir::SmallVector<mlir::SmallVector<size_t>> revIdxMap(nest.size());
|
||||
|
||||
for (size_t i = 0; i < idxMap.size(); i++) {
|
||||
for (size_t j = 0; j < nest.size(); j++) {
|
||||
if (nest[j] == idxMap[i]) {
|
||||
revIdxMap[j].push_back(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t i = nest.size() - 1;
|
||||
|
||||
// Reset the trip count for a loop back to zero and reinitializes
|
||||
// all indexes using the associated IV
|
||||
auto resetTrips = [&](size_t loopIdx) {
|
||||
trips[loopIdx] = 0;
|
||||
|
||||
for (size_t i : revIdxMap[loopIdx]) {
|
||||
idx[i] = idxBounds[i].lb;
|
||||
}
|
||||
};
|
||||
|
||||
// Increases the trip count of a loop by one and calculates the
|
||||
// next value of all indexes using the associated IV
|
||||
auto incTrips = [&](size_t loopIdx) {
|
||||
trips[loopIdx] += 1;
|
||||
|
||||
for (size_t i : revIdxMap[loopIdx]) {
|
||||
idx[i] += idxBounds[i].step;
|
||||
}
|
||||
};
|
||||
|
||||
// Iterate over the entire iteration space of the loop nest. The
|
||||
// variable i represents the index of the loop that is currently
|
||||
// stepped in the nest
|
||||
while (true) {
|
||||
// Loop has reached its maximum trip count. If the loop ist the
|
||||
// first in the nest, the entire space has been
|
||||
// covered. Otherwise, reset the trip count of the current loop
|
||||
// and step the loop above.
|
||||
if (trips[i] == tripCounts[i]) {
|
||||
if (i == 0)
|
||||
break;
|
||||
|
||||
resetTrips(i);
|
||||
i--;
|
||||
|
||||
incTrips(i);
|
||||
} else {
|
||||
// Trip count of the current loop hasn't been reached. If this
|
||||
// is the innermost loop, calculate a new index, fold all
|
||||
// values and write the result to the new tensor of
|
||||
// constants. Otherwise, switch to the next loop in the nest.
|
||||
if (i == nest.size() - 1) {
|
||||
size_t flatIdx = linearizeIndex(constantType.getShape(), idx);
|
||||
|
||||
mlir::Attribute newVal =
|
||||
fold(targetOp, extractOp.getResult(), denseVals[flatIdx]);
|
||||
|
||||
newDenseVals[flatIdx] = newVal;
|
||||
incTrips(i);
|
||||
} else {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new `arith.constant` operation with the updated tensor
|
||||
// of constants
|
||||
mlir::DenseElementsAttr newDenseElementsAttr = mlir::DenseElementsAttr::get(
|
||||
mlir::RankedTensorType::get(constantType.getShape(), foldedElementType),
|
||||
newDenseVals);
|
||||
|
||||
rewriter.setInsertionPoint(cdo);
|
||||
mlir::arith::ConstantOp newCdo = rewriter.create<mlir::arith::ConstantOp>(
|
||||
cdo.getLoc(), newDenseElementsAttr);
|
||||
|
||||
rewriter.setInsertionPoint(targetOp);
|
||||
|
||||
// Replace the last op in the chain of foldable operations with a
|
||||
// `tensor.extract` op on the new tensor of constants.
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractOp>(
|
||||
targetOp, targetOp->getResult(0).getType(), newCdo,
|
||||
extractOp.getIndices());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
class TensorAllocationCleanupPattern
|
||||
: public mlir::OpRewritePattern<mlir::func::FuncOp> {
|
||||
public:
|
||||
@@ -2023,7 +2406,8 @@ public:
|
||||
.add<CleanupPattern<mlir::tensor::ExtractOp, mlir::tensor::InsertOp>,
|
||||
CleanupPattern<mlir::tensor::ExtractSliceOp,
|
||||
mlir::tensor::InsertSliceOp>,
|
||||
TensorAllocationCleanupPattern>(op->getContext());
|
||||
ConstantDenseFoldingPattern, TensorAllocationCleanupPattern>(
|
||||
op->getContext());
|
||||
|
||||
if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed())
|
||||
this->signalPassFailure();
|
||||
|
||||
Reference in New Issue
Block a user