mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(compiler): Custom copy op from 1D tensor to avoir stack allocation from mlir memref to llvm lowering
This commit is contained in:
11
compiler/include/concretelang/Conversion/Tools.h
Normal file
11
compiler/include/concretelang/Conversion/Tools.h
Normal file
@@ -0,0 +1,11 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
|
||||
mlir::RewriterBase &rewriter,
|
||||
llvm::StringRef funcName,
|
||||
mlir::FunctionType funcType);
|
||||
@@ -57,6 +57,12 @@ void memref_bootstrap_lwe_u64(
|
||||
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
|
||||
uint64_t src_offset, uint64_t src_size,
|
||||
uint64_t src_stride, uint64_t *dst_allocated,
|
||||
uint64_t *dst_aligned, uint64_t dst_offset,
|
||||
uint64_t dst_size, uint64_t dst_stride);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
@@ -34,35 +35,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
|
||||
mlir::RewriterBase &rewriter,
|
||||
llvm::StringRef funcName,
|
||||
mlir::FunctionType funcType) {
|
||||
// Looking for the `funcName` Operation
|
||||
auto module = mlir::SymbolTable::getNearestSymbolTable(op);
|
||||
auto opFunc = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
|
||||
mlir::SymbolTable::lookupSymbolIn(module, funcName));
|
||||
if (!opFunc) {
|
||||
// Insert the forward declaration of the funcName
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
|
||||
|
||||
opFunc = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(), funcName,
|
||||
funcType);
|
||||
opFunc.setPrivate();
|
||||
} else {
|
||||
// Check if the `funcName` is well a private function
|
||||
if (!opFunc.isPrivate()) {
|
||||
op->emitError() << "the function \"" << funcName
|
||||
<< "\" conflicts with the concrete C API, please rename";
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
assert(mlir::SymbolTable::lookupSymbolIn(module, funcName)
|
||||
->template hasTrait<mlir::OpTrait::FunctionLike>());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Set of functions to generate generic types.
|
||||
// Generic types are used to add forward declarations without a specific type.
|
||||
// For example, we may need to add LWE ciphertext of different dimensions, or
|
||||
|
||||
@@ -9,6 +9,7 @@ add_mlir_dialect_library(BConcreteToBConcreteCAPI
|
||||
mlir-headers
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
ConcretelangConversion
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
@@ -5,3 +5,10 @@ add_subdirectory(FHETensorOpsToLinalg)
|
||||
add_subdirectory(ConcreteToBConcrete)
|
||||
add_subdirectory(BConcreteToBConcreteCAPI)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
|
||||
add_mlir_library(ConcretelangConversion
|
||||
Tools.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
)
|
||||
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRLowerableDialectsToLLVM
|
||||
LINK_LIBS PUBLIC
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
ConcretelangConversion
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
MLIRLLVMIR
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Dialect/RT/Analysis/Autopar.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTTypes.h"
|
||||
@@ -39,6 +40,61 @@ struct MLIRLowerableDialectsToLLVMPass
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// This rewrite pattern transforms any instance of `memref.copy`
|
||||
// operators on 1D memref.
|
||||
// This is introduced to avoid the MLIR lowering of `memref.copy` of ranked
|
||||
// memref that basically allocate unranked memref structure on the stack before
|
||||
// calling @memrefCopy.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// memref.copy %src, %dst : memref<Xxi64> to memref<Xxi64>
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// %_src = memref.cast %src = memref<Xxi64> to memref<?xi64>
|
||||
// %_dst = memref.cast %dst = memref<Xxi64> to memref<?xi64>
|
||||
// call @memref_copy_one_rank(%_src, %_dst) : (tensor<?xi64>, tensor<?xi64>) ->
|
||||
// ()
|
||||
// ```
|
||||
struct Memref1DCopyOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::memref::CopyOp> {
|
||||
Memref1DCopyOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::memref::CopyOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::memref::CopyOp copyOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1 ||
|
||||
copyOp.source().getType().cast<mlir::MemRefType>().getRank() != 1) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto opType = mlir::MemRefType::get({-1}, rewriter.getI64Type());
|
||||
// Insert forward declaration of the add_lwe_ciphertexts function
|
||||
{
|
||||
if (insertForwardDeclaration(
|
||||
copyOp, rewriter, "memref_copy_one_rank",
|
||||
mlir::FunctionType::get(rewriter.getContext(), {opType, opType},
|
||||
{}))
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
auto sourceOp = rewriter.create<mlir::memref::CastOp>(
|
||||
copyOp.getLoc(), opType, copyOp.source());
|
||||
auto targetOp = rewriter.create<mlir::memref::CastOp>(
|
||||
copyOp.getLoc(), opType, copyOp.target());
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
copyOp, "memref_copy_one_rank", mlir::TypeRange{},
|
||||
mlir::ValueRange{sourceOp, targetOp});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
// Setup the conversion target. We reuse the LLVMConversionTarget that
|
||||
// legalize LLVM dialect.
|
||||
@@ -63,6 +119,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
// Setup the set of the patterns rewriter. At this point we want to
|
||||
// convert the `scf` operations to `std` and `std` operations to `llvm`.
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
patterns.add<Memref1DCopyOpPattern>(&getContext(), 100);
|
||||
mlir::concretelang::populateRTToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
35
compiler/lib/Conversion/Tools.cpp
Normal file
35
compiler/lib/Conversion/Tools.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
|
||||
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
|
||||
mlir::RewriterBase &rewriter,
|
||||
llvm::StringRef funcName,
|
||||
mlir::FunctionType funcType) {
|
||||
// Looking for the `funcName` Operation
|
||||
auto module = mlir::SymbolTable::getNearestSymbolTable(op);
|
||||
auto opFunc = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
|
||||
mlir::SymbolTable::lookupSymbolIn(module, funcName));
|
||||
if (!opFunc) {
|
||||
// Insert the forward declaration of the funcName
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
|
||||
|
||||
opFunc = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(), funcName,
|
||||
funcType);
|
||||
opFunc.setPrivate();
|
||||
} else {
|
||||
// Check if the `funcName` is well a private function
|
||||
if (!opFunc.isPrivate()) {
|
||||
op->emitError() << "the function \"" << funcName
|
||||
<< "\" conflicts with the concrete C API, please rename";
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
assert(mlir::SymbolTable::lookupSymbolIn(module, funcName)
|
||||
->template hasTrait<mlir::OpTrait::FunctionLike>());
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "concretelang/Runtime/wrappers.h"
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
void memref_expand_lut_in_trivial_glwe_ct_u64(
|
||||
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
@@ -96,3 +97,20 @@ void memref_bootstrap_lwe_u64(
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset,
|
||||
glwe_ct_aligned + glwe_ct_offset);
|
||||
}
|
||||
|
||||
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
|
||||
uint64_t src_offset, uint64_t src_size,
|
||||
uint64_t src_stride, uint64_t *dst_allocated,
|
||||
uint64_t *dst_aligned, uint64_t dst_offset,
|
||||
uint64_t dst_size, uint64_t dst_stride) {
|
||||
assert(src_size == dst_size && "memref_copy_one_rank size differs");
|
||||
if (src_stride == dst_stride) {
|
||||
memcpy(dst_aligned + dst_offset, src_aligned + src_offset,
|
||||
src_size * sizeof(uint64_t));
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < src_size; i++) {
|
||||
dst_aligned[dst_offset + i * dst_stride] =
|
||||
src_aligned[src_offset + i * src_stride];
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user