mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): Batching: Hoist non-batchable operands produced by pure ops
The batching pass only creates a batched version of a batchable operation if all of its non-batchable operands are defined out ouf the outermost loop the iterating over the values of the batchable operand. This change also allows for operations to be batched if the non-batachable operands are generated by operations, which are pure and thus hoistable out of the outermost loop.
This commit is contained in:
@@ -40,7 +40,8 @@ def BatchableOpInterface : OpInterface<"BatchableOpInterface"> {
|
||||
/*retTy=*/"::mlir::Value",
|
||||
/*methodName=*/"createBatchedOperation",
|
||||
/*args=*/(ins "::mlir::ImplicitLocOpBuilder&":$builder,
|
||||
"::mlir::Value":$batchedOperands),
|
||||
"::mlir::Value":$batchedOperands,
|
||||
"::mlir::ValueRange":$hoistedNonBatchableOperands),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
llvm_unreachable("createBatchedOperation not implemented");
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
#include <mlir/Dialect/Tensor/IR/Tensor.h>
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
|
||||
|
||||
#include <concretelang/Interfaces/BatchableInterface.h>
|
||||
@@ -15,6 +16,20 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Checks if the value `v` is defined outside of the `loop` or a pure
|
||||
/// operation that can be safely replicated ouside the loop (i.e., all
|
||||
/// of its operands are also recursively either defined outside of the
|
||||
/// loop or pure).
|
||||
static bool isHoistable(mlir::Value v, mlir::scf::ForOp loop) {
|
||||
mlir::Operation *op = v.getDefiningOp();
|
||||
|
||||
return loop.isDefinedOutsideOfLoop(v) ||
|
||||
(op && mlir::isPure(op) && op->getNumResults() == 1 &&
|
||||
llvm::all_of(op->getOperands(), [&](mlir::Value operand) {
|
||||
return isHoistable(operand, loop);
|
||||
}));
|
||||
}
|
||||
|
||||
/// Checks if `forOp` has constant bounds and a constant step.
|
||||
static bool isStaticLoop(mlir::scf::ForOp forOp, int64_t *ilb = nullptr,
|
||||
int64_t *iub = nullptr, int64_t *istep = nullptr) {
|
||||
@@ -383,6 +398,39 @@ isQuasiAffineIVExpressionWithConstantStep(mlir::Value expr,
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Hoists the pure operation producing the value `v` out of
|
||||
/// `outermostFor` recursively. All newly created mappings are
|
||||
/// collected in `mapping`.
|
||||
static mlir::Value hoistPure(mlir::PatternRewriter &rewriter,
|
||||
mlir::scf::ForOp outermostFor,
|
||||
mlir::IRMapping &mapping, mlir::Value v) {
|
||||
if (outermostFor.isDefinedOutsideOfLoop(v))
|
||||
return v;
|
||||
|
||||
mlir::Operation *op = v.getDefiningOp();
|
||||
|
||||
assert(op && mlir::isPure(op) && op->getNumResults() == 1);
|
||||
|
||||
for (mlir::Value operand : op->getOperands()) {
|
||||
if (!mapping.contains(operand))
|
||||
mapping.map(operand, hoistPure(rewriter, outermostFor, mapping, operand));
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(outermostFor);
|
||||
|
||||
mlir::Operation *clonedOp = rewriter.clone(*op, mapping);
|
||||
|
||||
return clonedOp->getResult(0);
|
||||
}
|
||||
|
||||
/// Hoists the pure operation producing the value `v` out of
|
||||
/// `outermostFor` recursively.
|
||||
static mlir::Value hoistPure(mlir::PatternRewriter &rewriter,
|
||||
mlir::scf::ForOp outermostFor, mlir::Value v) {
|
||||
mlir::IRMapping mapping;
|
||||
return hoistPure(rewriter, outermostFor, mapping, v);
|
||||
}
|
||||
|
||||
/// Hoists a an operation embedded into a loop nest that and that is
|
||||
/// indexed using quasi-affine expressions of the loops' IVs as a
|
||||
/// `tensor.extract_slice` outside of the outermost loop
|
||||
@@ -618,9 +666,10 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that other args are defined outside the loop nest
|
||||
// Verify that other args are defined outside the loop nest or
|
||||
// hoistable
|
||||
if (!llvm::all_of(scalarOp.getNonBatchableOperands(), [&](mlir::Value v) {
|
||||
return currOutermostFor.isDefinedOutsideOfLoop(v);
|
||||
return isHoistable(v, currOutermostFor);
|
||||
})) {
|
||||
return mlir::WalkResult::skip();
|
||||
}
|
||||
@@ -666,11 +715,18 @@ public:
|
||||
llvm::SmallVector<mlir::ReassociationIndices>{indices});
|
||||
}
|
||||
|
||||
// Hoist all non-batchable operands
|
||||
llvm::SmallVector<mlir::Value> hoistedNonBatchableOperands;
|
||||
for (mlir::Value operand : targetOp.getNonBatchableOperands()) {
|
||||
hoistedNonBatchableOperands.push_back(
|
||||
hoistPure(rewriter, outermostFor, operand));
|
||||
}
|
||||
|
||||
// Create the batched operation and pass flattened, batched
|
||||
// operands
|
||||
mlir::ImplicitLocOpBuilder ilob(targetExtractOp.getLoc(), rewriter);
|
||||
mlir::Value batchedOpResult =
|
||||
targetOp.createBatchedOperation(ilob, flattenedSlice);
|
||||
mlir::Value batchedOpResult = targetOp.createBatchedOperation(
|
||||
ilob, flattenedSlice, hoistedNonBatchableOperands);
|
||||
|
||||
mlir::Value expandedBatchResultTensor;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user