feat(compiler): Batching: Hoist non-batchable operands produced by pure ops

The batching pass only creates a batched version of a batchable
operation if all of its non-batchable operands are defined outside of
the outermost loop iterating over the values of the batchable operand.

This change also allows operations to be batched if their
non-batchable operands are generated by operations that are pure
and thus hoistable out of the outermost loop.
This commit is contained in:
Andi Drebes
2023-03-14 17:14:40 +01:00
parent 3309615d7b
commit b24709a1ec
2 changed files with 62 additions and 5 deletions

View File

@@ -40,7 +40,8 @@ def BatchableOpInterface : OpInterface<"BatchableOpInterface"> {
/*retTy=*/"::mlir::Value",
/*methodName=*/"createBatchedOperation",
/*args=*/(ins "::mlir::ImplicitLocOpBuilder&":$builder,
"::mlir::Value":$batchedOperands),
"::mlir::Value":$batchedOperands,
"::mlir::ValueRange":$hoistedNonBatchableOperands),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("createBatchedOperation not implemented");

View File

@@ -8,6 +8,7 @@
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include <concretelang/Interfaces/BatchableInterface.h>
@@ -15,6 +16,20 @@
namespace mlir {
namespace concretelang {
/// Returns `true` if `v` can be made available outside of `loop`.
///
/// This is the case if `v` is already defined outside of `loop`, or
/// if `v` is the sole result of an operation for which
/// `mlir::isPure` holds and whose operands are all recursively
/// hoistable themselves. Values without a defining operation (e.g.,
/// block arguments) that are not defined outside of `loop` are never
/// hoistable.
static bool isHoistable(mlir::Value v, mlir::scf::ForOp loop) {
  // Values already living outside of the loop need no further checks.
  if (loop.isDefinedOutsideOfLoop(v))
    return true;

  // Only single-result, pure operations are candidates for replication
  // outside of the loop.
  mlir::Operation *producer = v.getDefiningOp();
  if (!producer || !mlir::isPure(producer) || producer->getNumResults() != 1)
    return false;

  // Every operand of the producer must be hoistable as well.
  for (mlir::Value operand : producer->getOperands()) {
    if (!isHoistable(operand, loop))
      return false;
  }

  return true;
}
/// Checks if `forOp` has constant bounds and a constant step.
static bool isStaticLoop(mlir::scf::ForOp forOp, int64_t *ilb = nullptr,
int64_t *iub = nullptr, int64_t *istep = nullptr) {
@@ -383,6 +398,39 @@ isQuasiAffineIVExpressionWithConstantStep(mlir::Value expr,
return false;
}
/// Recursively clones the pure, single-result operation producing
/// `v` right before `outermostFor` and returns the result of the
/// clone. If `v` is already defined outside of `outermostFor`, `v`
/// itself is returned and nothing is cloned.
///
/// Operands of the producing operation that are not yet present in
/// `mapping` are hoisted first and recorded in `mapping`, which is
/// then used to remap the operands of the clone.
///
/// Note: whenever an operation is cloned, the insertion point of
/// `rewriter` is set to `outermostFor` and is not restored.
static mlir::Value hoistPure(mlir::PatternRewriter &rewriter,
                             mlir::scf::ForOp outermostFor,
                             mlir::IRMapping &mapping, mlir::Value v) {
  // Nothing to hoist for values that already live outside of the loop.
  if (outermostFor.isDefinedOutsideOfLoop(v))
    return v;

  mlir::Operation *op = v.getDefiningOp();

  assert(op && mlir::isPure(op) && op->getNumResults() == 1);

  // Make sure a hoisted counterpart exists for every operand before
  // cloning the operation itself.
  for (mlir::Value operand : op->getOperands()) {
    if (mapping.contains(operand))
      continue;

    mlir::Value hoistedOperand =
        hoistPure(rewriter, outermostFor, mapping, operand);
    mapping.map(operand, hoistedOperand);
  }

  // Hoisting the operands may have moved the insertion point, so set
  // it (again) right before the outermost loop for the clone.
  rewriter.setInsertionPoint(outermostFor);

  return rewriter.clone(*op, mapping)->getResult(0);
}
/// Convenience overload: recursively hoists the pure operation
/// producing the value `v` out of `outermostFor`, starting from an
/// empty value mapping that is discarded afterwards.
static mlir::Value hoistPure(mlir::PatternRewriter &rewriter,
                             mlir::scf::ForOp outermostFor, mlir::Value v) {
  // Mappings are local to this call and not shared between calls.
  mlir::IRMapping localMapping;

  return hoistPure(rewriter, outermostFor, localMapping, v);
}
/// Hoists an operation that is embedded into a loop nest and that is
/// indexed using quasi-affine expressions of the loops' IVs as a
/// `tensor.extract_slice` outside of the outermost loop
@@ -618,9 +666,10 @@ public:
}
}
// Verify that other args are defined outside the loop nest
// Verify that other args are defined outside the loop nest or
// hoistable
if (!llvm::all_of(scalarOp.getNonBatchableOperands(), [&](mlir::Value v) {
return currOutermostFor.isDefinedOutsideOfLoop(v);
return isHoistable(v, currOutermostFor);
})) {
return mlir::WalkResult::skip();
}
@@ -666,11 +715,18 @@ public:
llvm::SmallVector<mlir::ReassociationIndices>{indices});
}
// Hoist all non-batchable operands
llvm::SmallVector<mlir::Value> hoistedNonBatchableOperands;
for (mlir::Value operand : targetOp.getNonBatchableOperands()) {
hoistedNonBatchableOperands.push_back(
hoistPure(rewriter, outermostFor, operand));
}
// Create the batched operation and pass flattened, batched
// operands
mlir::ImplicitLocOpBuilder ilob(targetExtractOp.getLoc(), rewriter);
mlir::Value batchedOpResult =
targetOp.createBatchedOperation(ilob, flattenedSlice);
mlir::Value batchedOpResult = targetOp.createBatchedOperation(
ilob, flattenedSlice, hoistedNonBatchableOperands);
mlir::Value expandedBatchResultTensor;