fix(compiler): Custom copy op from 1D tensor to avoir stack allocation from mlir memref to llvm lowering

This commit is contained in:
Quentin Bourgerie
2022-04-07 23:49:30 +02:00
parent af300055a7
commit 247d60503d
9 changed files with 137 additions and 29 deletions

View File

@@ -25,6 +25,7 @@
#include "llvm/ADT/Sequence.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/Analysis/Autopar.h"
#include "concretelang/Dialect/RT/IR/RTTypes.h"
@@ -39,6 +40,61 @@ struct MLIRLowerableDialectsToLLVMPass
};
} // namespace
// This rewrite pattern transforms any instance of `memref.copy`
// operators on 1D memref.
// This is introduced to avoid the MLIR lowering of `memref.copy` of ranked
// memref that basically allocate unranked memref structure on the stack before
// calling @memrefCopy.
//
// Example:
//
// ```mlir
// memref.copy %src, %dst : memref<Xxi64> to memref<Xxi64>
// ```
//
// becomes:
//
// ```mlir
// %_src = memref.cast %src = memref<Xxi64> to memref<?xi64>
// %_dst = memref.cast %dst = memref<Xxi64> to memref<?xi64>
// call @memref_copy_one_rank(%_src, %_dst) : (tensor<?xi64>, tensor<?xi64>) ->
// ()
// ```
struct Memref1DCopyOpPattern
: public mlir::OpRewritePattern<mlir::memref::CopyOp> {
Memref1DCopyOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::memref::CopyOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::memref::CopyOp copyOp,
mlir::PatternRewriter &rewriter) const override {
if (copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1 ||
copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1) {
return mlir::failure();
}
auto opType = mlir::MemRefType::get({-1}, rewriter.getI64Type());
// Insert forward declaration of the add_lwe_ciphertexts function
{
if (insertForwardDeclaration(
copyOp, rewriter, "memref_copy_one_rank",
mlir::FunctionType::get(rewriter.getContext(), {opType, opType},
{}))
.failed()) {
return mlir::failure();
}
}
auto sourceOp = rewriter.create<mlir::memref::CastOp>(
copyOp.getLoc(), opType, copyOp.source());
auto targetOp = rewriter.create<mlir::memref::CastOp>(
copyOp.getLoc(), opType, copyOp.target());
rewriter.replaceOpWithNewOp<mlir::CallOp>(
copyOp, "memref_copy_one_rank", mlir::TypeRange{},
mlir::ValueRange{sourceOp, targetOp});
return mlir::success();
};
};
void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
// Setup the conversion target. We reuse the LLVMConversionTarget that
// legalize LLVM dialect.
@@ -63,6 +119,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
// Setup the set of the patterns rewriter. At this point we want to
// convert the `scf` operations to `std` and `std` operations to `llvm`.
mlir::RewritePatternSet patterns(&getContext());
patterns.add<Memref1DCopyOpPattern>(&getContext(), 100);
mlir::concretelang::populateRTToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);