mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix(compiler): Custom copy op from 1D tensor to avoir stack allocation from mlir memref to llvm lowering
This commit is contained in:
@@ -25,6 +25,7 @@
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Dialect/RT/Analysis/Autopar.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTTypes.h"
|
||||
@@ -39,6 +40,61 @@ struct MLIRLowerableDialectsToLLVMPass
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// This rewrite pattern transforms any instance of `memref.copy`
|
||||
// operators on 1D memref.
|
||||
// This is introduced to avoid the MLIR lowering of `memref.copy` of ranked
|
||||
// memref that basically allocate unranked memref structure on the stack before
|
||||
// calling @memrefCopy.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// memref.copy %src, %dst : memref<Xxi64> to memref<Xxi64>
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// %_src = memref.cast %src = memref<Xxi64> to memref<?xi64>
|
||||
// %_dst = memref.cast %dst = memref<Xxi64> to memref<?xi64>
|
||||
// call @memref_copy_one_rank(%_src, %_dst) : (tensor<?xi64>, tensor<?xi64>) ->
|
||||
// ()
|
||||
// ```
|
||||
struct Memref1DCopyOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::memref::CopyOp> {
|
||||
Memref1DCopyOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::memref::CopyOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::memref::CopyOp copyOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1 ||
|
||||
copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto opType = mlir::MemRefType::get({-1}, rewriter.getI64Type());
|
||||
// Insert forward declaration of the add_lwe_ciphertexts function
|
||||
{
|
||||
if (insertForwardDeclaration(
|
||||
copyOp, rewriter, "memref_copy_one_rank",
|
||||
mlir::FunctionType::get(rewriter.getContext(), {opType, opType},
|
||||
{}))
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
auto sourceOp = rewriter.create<mlir::memref::CastOp>(
|
||||
copyOp.getLoc(), opType, copyOp.source());
|
||||
auto targetOp = rewriter.create<mlir::memref::CastOp>(
|
||||
copyOp.getLoc(), opType, copyOp.target());
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
copyOp, "memref_copy_one_rank", mlir::TypeRange{},
|
||||
mlir::ValueRange{sourceOp, targetOp});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
// Setup the conversion target. We reuse the LLVMConversionTarget that
|
||||
// legalize LLVM dialect.
|
||||
@@ -63,6 +119,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
// Setup the set of the patterns rewriter. At this point we want to
|
||||
// convert the `scf` operations to `std` and `std` operations to `llvm`.
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
patterns.add<Memref1DCopyOpPattern>(&getContext(), 100);
|
||||
mlir::concretelang::populateRTToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
Reference in New Issue
Block a user