fix(compiler): Custom copy op from 1D tensor to avoid stack allocation from mlir memref to llvm lowering

This commit is contained in:
Quentin Bourgerie
2022-04-07 23:49:30 +02:00
parent af300055a7
commit 247d60503d
9 changed files with 137 additions and 29 deletions

View File

@@ -0,0 +1,11 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "mlir/IR/PatternMatch.h"
/// Ensures a private forward declaration of `funcName` with type `funcType`
/// exists in the symbol table enclosing `op`; emits an error and fails when a
/// non-private symbol with that name is already present.
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
mlir::RewriterBase &rewriter,
llvm::StringRef funcName,
mlir::FunctionType funcType);

View File

@@ -57,6 +57,12 @@ void memref_bootstrap_lwe_u64(
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
mlir::concretelang::RuntimeContext *context);
// Copies a rank-1 memref of u64 elements, passed using the MLIR expanded
// memref calling convention (allocated ptr, aligned ptr, offset, size,
// stride for each memref). Declared here so the lowering can call it
// instead of the generic @memrefCopy runtime path.
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
uint64_t src_offset, uint64_t src_size,
uint64_t src_stride, uint64_t *dst_allocated,
uint64_t *dst_aligned, uint64_t dst_offset,
uint64_t dst_size, uint64_t dst_stride);
}
#endif

View File

@@ -12,6 +12,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
@@ -34,35 +35,6 @@ public:
}
};
// NOTE(review): this definition is duplicated verbatim in Conversion/Tools.cpp
// within this same commit — prefer keeping only the shared Tools.cpp copy.
/// Ensures a private declaration of `funcName` with type `funcType` exists in
/// the symbol table enclosing `op`. Fails (with an error on `op`) when a
/// non-private symbol of that name already exists.
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
mlir::RewriterBase &rewriter,
llvm::StringRef funcName,
mlir::FunctionType funcType) {
// Looking for the `funcName` Operation
auto module = mlir::SymbolTable::getNearestSymbolTable(op);
auto opFunc = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
mlir::SymbolTable::lookupSymbolIn(module, funcName));
if (!opFunc) {
// Insert the forward declaration of the funcName
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
opFunc = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(), funcName,
funcType);
// Private visibility so the declaration does not clash with user symbols.
opFunc.setPrivate();
} else {
// Check if the `funcName` is well a private function
if (!opFunc.isPrivate()) {
op->emitError() << "the function \"" << funcName
<< "\" conflicts with the concrete C API, please rename";
return mlir::failure();
}
}
// The looked-up/created symbol must behave like a function.
assert(mlir::SymbolTable::lookupSymbolIn(module, funcName)
->template hasTrait<mlir::OpTrait::FunctionLike>());
return mlir::success();
}
// Set of functions to generate generic types.
// Generic types are used to add forward declarations without a specific type.
// For example, we may need to add LWE ciphertext of different dimensions, or

View File

@@ -9,6 +9,7 @@ add_mlir_dialect_library(BConcreteToBConcreteCAPI
mlir-headers
LINK_LIBS PUBLIC
ConcretelangConversion
MLIRIR
MLIRTransforms
)

View File

@@ -5,3 +5,10 @@ add_subdirectory(FHETensorOpsToLinalg)
add_subdirectory(ConcreteToBConcrete)
add_subdirectory(BConcreteToBConcreteCAPI)
add_subdirectory(MLIRLowerableDialectsToLLVM)
add_mlir_library(ConcretelangConversion
Tools.cpp
LINK_LIBS PUBLIC
MLIRIR
)

View File

@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRLowerableDialectsToLLVM
LINK_LIBS PUBLIC
${dialect_libs}
${conversion_libs}
ConcretelangConversion
MLIRIR
MLIRTransforms
MLIRLLVMIR

View File

@@ -25,6 +25,7 @@
#include "llvm/ADT/Sequence.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/Analysis/Autopar.h"
#include "concretelang/Dialect/RT/IR/RTTypes.h"
@@ -39,6 +40,61 @@ struct MLIRLowerableDialectsToLLVMPass
};
} // namespace
// This rewrite pattern transforms any instance of `memref.copy`
// operators on 1D memref.
// This is introduced to avoid the MLIR lowering of `memref.copy` of ranked
// memref that basically allocate unranked memref structure on the stack before
// calling @memrefCopy.
//
// Example:
//
// ```mlir
// memref.copy %src, %dst : memref<Xxi64> to memref<Xxi64>
// ```
//
// becomes:
//
// ```mlir
// %_src = memref.cast %src : memref<Xxi64> to memref<?xi64>
// %_dst = memref.cast %dst : memref<Xxi64> to memref<?xi64>
// call @memref_copy_one_rank(%_src, %_dst)
//     : (memref<?xi64>, memref<?xi64>) -> ()
// ```
// Rewrites a 1D `memref.copy` into a call to the runtime function
// `memref_copy_one_rank`, avoiding the default MLIR lowering that
// materializes unranked memref descriptors on the stack before calling
// @memrefCopy.
struct Memref1DCopyOpPattern
    : public mlir::OpRewritePattern<mlir::memref::CopyOp> {
  Memref1DCopyOpPattern(mlir::MLIRContext *context,
                        mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<mlir::memref::CopyOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::memref::CopyOp copyOp,
                  mlir::PatternRewriter &rewriter) const override {
    // Only handle copies where BOTH operands are rank-1 memrefs.
    // (The original code checked the source rank twice, so a rank-1
    // source with a higher-rank target would have been matched.)
    if (copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1 ||
        copyOp.target().getType().cast<mlir::MemRefType>().getRank() != 1) {
      return mlir::failure();
    }
    // Dynamically-sized 1D memref of i64: the type of both runtime-call
    // operands, so one declaration covers every static size.
    auto opType = mlir::MemRefType::get({-1}, rewriter.getI64Type());
    // Insert forward declaration of the memref_copy_one_rank function.
    if (insertForwardDeclaration(
            copyOp, rewriter, "memref_copy_one_rank",
            mlir::FunctionType::get(rewriter.getContext(), {opType, opType},
                                    {}))
            .failed()) {
      return mlir::failure();
    }
    // Erase the static sizes via casts to memref<?xi64>.
    auto sourceOp = rewriter.create<mlir::memref::CastOp>(
        copyOp.getLoc(), opType, copyOp.source());
    auto targetOp = rewriter.create<mlir::memref::CastOp>(
        copyOp.getLoc(), opType, copyOp.target());
    rewriter.replaceOpWithNewOp<mlir::CallOp>(
        copyOp, "memref_copy_one_rank", mlir::TypeRange{},
        mlir::ValueRange{sourceOp, targetOp});
    return mlir::success();
  }
};
void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
// Setup the conversion target. We reuse the LLVMConversionTarget that
// legalize LLVM dialect.
@@ -63,6 +119,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
// Setup the set of the patterns rewriter. At this point we want to
// convert the `scf` operations to `std` and `std` operations to `llvm`.
mlir::RewritePatternSet patterns(&getContext());
patterns.add<Memref1DCopyOpPattern>(&getContext(), 100);
mlir::concretelang::populateRTToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);

View File

@@ -0,0 +1,35 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "concretelang/Conversion/Tools.h"
/// Ensures that a private declaration of `funcName` with type `funcType`
/// exists in the symbol table that encloses `op`.
///
/// When the symbol is absent, a private `FuncOp` declaration is inserted at
/// the top of the enclosing module. When a symbol with that name already
/// exists it must be private, otherwise an error is emitted on `op` and the
/// function fails.
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
                                             mlir::RewriterBase &rewriter,
                                             llvm::StringRef funcName,
                                             mlir::FunctionType funcType) {
  // Resolve `funcName` in the nearest enclosing symbol table.
  auto symbolTableOp = mlir::SymbolTable::getNearestSymbolTable(op);
  auto existing = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
      mlir::SymbolTable::lookupSymbolIn(symbolTableOp, funcName));
  if (existing) {
    // The symbol is already there: it must be a private declaration,
    // otherwise it clashes with a user-defined function.
    if (!existing.isPrivate()) {
      op->emitError() << "the function \"" << funcName
                      << "\" conflicts with the concrete C API, please rename";
      return mlir::failure();
    }
  } else {
    // Declare the function at the start of the module, as a private symbol.
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&symbolTableOp->getRegion(0).front());
    auto decl = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(),
                                              funcName, funcType);
    decl.setPrivate();
  }
  // Whatever path was taken, the symbol must now be function-like.
  assert(mlir::SymbolTable::lookupSymbolIn(symbolTableOp, funcName)
             ->template hasTrait<mlir::OpTrait::FunctionLike>());
  return mlir::success();
}

View File

@@ -6,6 +6,7 @@
#include "concretelang/Runtime/wrappers.h"
#include <assert.h>
#include <stdio.h>
#include <string.h>
void memref_expand_lut_in_trivial_glwe_ct_u64(
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
@@ -96,3 +97,20 @@ void memref_bootstrap_lwe_u64(
out_aligned + out_offset, ct0_aligned + ct0_offset,
glwe_ct_aligned + glwe_ct_offset);
}
// Copies a rank-1 memref of u64 elements, passed using the MLIR expanded
// memref calling convention (allocated ptr, aligned ptr, offset, size,
// stride for each memref). Called by the lowering of 1D `memref.copy` to
// avoid the generic @memrefCopy path and its stack-allocated unranked
// descriptors. The `*_allocated` pointers are part of the ABI but unused.
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
                          uint64_t src_offset, uint64_t src_size,
                          uint64_t src_stride, uint64_t *dst_allocated,
                          uint64_t *dst_aligned, uint64_t dst_offset,
                          uint64_t dst_size, uint64_t dst_stride) {
  assert(src_size == dst_size && "memref_copy_one_rank size differs");
  // Fast path: only valid when BOTH buffers are contiguous (stride 1).
  // The original condition `src_stride == dst_stride` was wrong: for an
  // equal non-unit stride the elements are not contiguous, so a memcpy of
  // `src_size` consecutive words copies the wrong elements.
  if (src_stride == 1 && dst_stride == 1) {
    memcpy(dst_aligned + dst_offset, src_aligned + src_offset,
           src_size * sizeof(uint64_t));
    return;
  }
  // Generic strided element-by-element copy.
  for (size_t i = 0; i < src_size; i++) {
    dst_aligned[dst_offset + i * dst_stride] =
        src_aligned[src_offset + i * src_stride];
  }
}