mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): First draft of lowering from LowLFHE to std with fct call (#62)
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_LOWLFHETOCONCRETECAPI_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_LOWLFHETOCONCRETECAPI_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
/// Create a pass to convert `LowLFHE` operators to function call to the
|
||||
/// `ConcreteCAPI`
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEToConcreteCAPIPass();
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -8,9 +8,11 @@
|
||||
|
||||
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
|
||||
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
|
||||
#include "zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h"
|
||||
#include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
#include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "zamalang/Conversion/Passes.h.inc"
|
||||
|
||||
@@ -25,6 +25,12 @@ def MidLFHEToLowLFHE : Pass<"midlfhe-to-lowlfhe", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def LowLFHEToConcreteCAPI : Pass<"lowlfhe-to-concrete-c-api", "mlir::ModuleOp"> {
|
||||
let summary = "Lower operations from the LowLFHE dialect to std with function call to the Concrete C API";
|
||||
let constructor = "mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass()";
|
||||
let dependentDialects = ["mlir::zamalang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
|
||||
let constructor = "mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()";
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
add_subdirectory(HLFHEToMidLFHE)
|
||||
add_subdirectory(MidLFHEToLowLFHE)
|
||||
add_subdirectory(HLFHETensorOpsToLinalg)
|
||||
add_subdirectory(LowLFHEToConcreteCAPI)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
|
||||
16
compiler/lib/Conversion/LowLFHEToConcreteCAPI/CMakeLists.txt
Normal file
16
compiler/lib/Conversion/LowLFHEToConcreteCAPI/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
add_mlir_dialect_library(LowLFHEToConcreteCAPI
|
||||
LowLFHEToConcreteCAPI.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE
|
||||
|
||||
DEPENDS
|
||||
LowLFHEDialect
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_libraries(LowLFHEToConcreteCAPI PUBLIC MLIRIR)
|
||||
@@ -0,0 +1,161 @@
|
||||
#include "mlir//IR/BuiltinTypes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
|
||||
|
||||
/// LowLFHEOpToConcreteCAPICallPattern<Op> match the `Op` Operation and
|
||||
/// replace with a call to `funcName`, the funcName should be an external
|
||||
/// function that was linked later. It insert the forward declaration of the
|
||||
/// private `funcName` if it not already in the symbol table.
|
||||
/// The C signature of the function should be `void funcName(int *err, out,
|
||||
/// arg0, arg1)`, the pattern rewrite:
|
||||
/// ```
|
||||
/// out = op(arg0, arg1)
|
||||
/// ```
|
||||
/// to
|
||||
/// ```
|
||||
/// err = memref.alloc() : memref<index>
|
||||
/// out = _allocate_(err);
|
||||
/// call_op(err, out, arg0, arg1);
|
||||
/// ```
|
||||
template <typename Op>
|
||||
struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
|
||||
mlir::StringRef funcName,
|
||||
mlir::StringRef allocName,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<Op>(context, benefit), funcName(funcName),
|
||||
allocName(allocName) {}
|
||||
|
||||
mlir::LogicalResult static insertForwardDeclaration(
|
||||
Op op, mlir::PatternRewriter &rewriter, llvm::StringRef funcName,
|
||||
mlir::FunctionType funcType) {
|
||||
// Looking for the `funcName` Operation
|
||||
auto module = mlir::SymbolTable::getNearestSymbolTable(op);
|
||||
auto opFunc = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
|
||||
mlir::SymbolTable::lookupSymbolIn(module, funcName));
|
||||
if (!opFunc) {
|
||||
// Insert the forward declaration of the funcName
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
|
||||
|
||||
opFunc = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(), funcName,
|
||||
funcType);
|
||||
opFunc.setPrivate();
|
||||
} else {
|
||||
// Check if the `funcName` is well a private function
|
||||
if (!opFunc.isPrivate()) {
|
||||
op.emitError() << "the function \"" << funcName
|
||||
<< "\" conflicts with the concrete C API, please rename";
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
assert(mlir::SymbolTable::lookupSymbolIn(module, funcName)
|
||||
->template hasTrait<mlir::OpTrait::FunctionLike>());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
|
||||
auto errType =
|
||||
mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext()));
|
||||
// Insert forward declaration of the operator function
|
||||
{
|
||||
mlir::SmallVector<mlir::Type, 4> operands{errType,
|
||||
op->getResultTypes().front()};
|
||||
for (auto ty : op->getOperandTypes()) {
|
||||
operands.push_back(ty);
|
||||
}
|
||||
auto funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(), operands, {});
|
||||
if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the alloc function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {errType, rewriter.getIndexType()},
|
||||
{op->getResultTypes().front()});
|
||||
if (insertForwardDeclaration(op, rewriter, allocName, funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Replace the operation with a call to the `funcName`
|
||||
{
|
||||
// Create the err value
|
||||
auto err = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
|
||||
// Add the call to the allocation
|
||||
// TODO - 2018
|
||||
auto lweSize = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(2048));
|
||||
mlir::SmallVector<mlir::Value, 1> allocOperands{err, lweSize};
|
||||
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, allocName, op.getType(), allocOperands);
|
||||
|
||||
// Add err and allocated value to operands
|
||||
mlir::SmallVector<mlir::Value, 4> newOperands{err, alloc.getResult(0)};
|
||||
for (auto operand : op->getOperands()) {
|
||||
newOperands.push_back(operand);
|
||||
}
|
||||
rewriter.create<mlir::CallOp>(op.getLoc(), funcName, mlir::TypeRange{},
|
||||
newOperands);
|
||||
}
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
std::string funcName;
|
||||
std::string allocName;
|
||||
};
|
||||
|
||||
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
|
||||
/// operators to the corresponding function call to the `Concrete C API`.
|
||||
void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
|
||||
mlir::zamalang::LowLFHE::AddLweCiphertextsOp>>(
|
||||
patterns.getContext(), "add_lwe_ciphertexts_u64",
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct LowLFHEToConcreteCAPIPass
|
||||
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void LowLFHEToConcreteCAPIPass::runOnOperation() {
|
||||
// Setup the conversion target.
|
||||
mlir::ConversionTarget target(getContext());
|
||||
target.addIllegalDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
|
||||
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
|
||||
mlir::memref::MemRefDialect>();
|
||||
|
||||
// Setup rewrite patterns
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
populateLowLFHEToConcreteCAPICall(patterns);
|
||||
|
||||
// Apply the conversion
|
||||
mlir::ModuleOp op = getOperation();
|
||||
if (mlir::applyFullConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEToConcreteCAPIPass() {
|
||||
return std::make_unique<LowLFHEToConcreteCAPIPass>();
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
@@ -37,7 +38,9 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
|
||||
// Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and
|
||||
// add our types conversion to `llvm` compatible type.
|
||||
mlir::LLVMTypeConverter typeConverter(&getContext());
|
||||
mlir::LowerToLLVMOptions options(&getContext());
|
||||
options.useBarePtrCallConv = true;
|
||||
mlir::LLVMTypeConverter typeConverter(&getContext(), options);
|
||||
typeConverter.addConversion(convertTypes);
|
||||
|
||||
// Setup the set of the patterns rewriter. At this point we want to
|
||||
@@ -45,6 +48,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
mlir::populateLoopToStdConversionPatterns(patterns);
|
||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// Apply a `FullConversion` to `llvm`.
|
||||
auto module = getOperation();
|
||||
@@ -57,7 +61,7 @@ llvm::Optional<mlir::Type>
|
||||
MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
|
||||
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
|
||||
return mlir::LLVM::LLVMPointerType::get(
|
||||
mlir::IntegerType::get(type.getContext(), 8));
|
||||
mlir::IntegerType::get(type.getContext(), 64));
|
||||
}
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
@@ -41,6 +41,8 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), enablePass);
|
||||
constraint = defaultGlobalFHECircuitConstraint;
|
||||
|
||||
// Run the passes
|
||||
|
||||
@@ -253,6 +253,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
context.getOrLoadDialect<mlir::StandardOpsDialect>();
|
||||
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
|
||||
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
|
||||
|
||||
if (cmdline::verifyDiagnostics)
|
||||
context.printOpOnDiagnostic(false);
|
||||
|
||||
Reference in New Issue
Block a user