feat(compiler): First draft of lowering from LowLFHE to std with fct call (#62)

This commit is contained in:
Quentin Bourgerie
2021-08-11 10:11:36 +02:00
parent d0877536ed
commit b22f585380
9 changed files with 211 additions and 2 deletions

View File

@@ -0,0 +1,16 @@
#ifndef ZAMALANG_CONVERSION_LOWLFHETOCONCRETECAPI_PASS_H_
#define ZAMALANG_CONVERSION_LOWLFHETOCONCRETECAPI_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace zamalang {
/// Create a pass to convert `LowLFHE` operators to function call to the
/// `ConcreteCAPI`
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEToConcreteCAPIPass();
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -8,9 +8,11 @@
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
#include "zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h"
#include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
#include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#define GEN_PASS_CLASSES
#include "zamalang/Conversion/Passes.h.inc"

View File

@@ -25,6 +25,12 @@ def MidLFHEToLowLFHE : Pass<"midlfhe-to-lowlfhe", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::linalg::LinalgDialect"];
}
def LowLFHEToConcreteCAPI : Pass<"lowlfhe-to-concrete-c-api", "mlir::ModuleOp"> {
let summary = "Lower operations from the LowLFHE dialect to std with function call to the Concrete C API";
let constructor = "mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass()";
let dependentDialects = ["mlir::zamalang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
let constructor = "mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()";

View File

@@ -1,4 +1,5 @@
add_subdirectory(HLFHEToMidLFHE)
add_subdirectory(MidLFHEToLowLFHE)
add_subdirectory(HLFHETensorOpsToLinalg)
add_subdirectory(LowLFHEToConcreteCAPI)
add_subdirectory(MLIRLowerableDialectsToLLVM)

View File

@@ -0,0 +1,16 @@
add_mlir_dialect_library(LowLFHEToConcreteCAPI
LowLFHEToConcreteCAPI.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE
DEPENDS
LowLFHEDialect
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
)
target_link_libraries(LowLFHEToConcreteCAPI PUBLIC MLIRIR)

View File

@@ -0,0 +1,161 @@
#include "mlir//IR/BuiltinTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
/// LowLFHEOpToConcreteCAPICallPattern<Op> match the `Op` Operation and
/// replace with a call to `funcName`, the funcName should be an external
/// function that was linked later. It insert the forward declaration of the
/// private `funcName` if it not already in the symbol table.
/// The C signature of the function should be `void funcName(int *err, out,
/// arg0, arg1)`, the pattern rewrite:
/// ```
/// out = op(arg0, arg1)
/// ```
/// to
/// ```
/// err = memref.alloc() : memref<index>
/// out = _allocate_(err);
/// call_op(err, out, arg0, arg1);
/// ```
template <typename Op>
struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
mlir::StringRef funcName,
mlir::StringRef allocName,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<Op>(context, benefit), funcName(funcName),
allocName(allocName) {}
mlir::LogicalResult static insertForwardDeclaration(
Op op, mlir::PatternRewriter &rewriter, llvm::StringRef funcName,
mlir::FunctionType funcType) {
// Looking for the `funcName` Operation
auto module = mlir::SymbolTable::getNearestSymbolTable(op);
auto opFunc = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
mlir::SymbolTable::lookupSymbolIn(module, funcName));
if (!opFunc) {
// Insert the forward declaration of the funcName
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
opFunc = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(), funcName,
funcType);
opFunc.setPrivate();
} else {
// Check if the `funcName` is well a private function
if (!opFunc.isPrivate()) {
op.emitError() << "the function \"" << funcName
<< "\" conflicts with the concrete C API, please rename";
return mlir::failure();
}
}
assert(mlir::SymbolTable::lookupSymbolIn(module, funcName)
->template hasTrait<mlir::OpTrait::FunctionLike>());
return mlir::success();
}
mlir::LogicalResult
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
auto errType =
mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext()));
// Insert forward declaration of the operator function
{
mlir::SmallVector<mlir::Type, 4> operands{errType,
op->getResultTypes().front()};
for (auto ty : op->getOperandTypes()) {
operands.push_back(ty);
}
auto funcType =
mlir::FunctionType::get(rewriter.getContext(), operands, {});
if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the alloc function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {errType, rewriter.getIndexType()},
{op->getResultTypes().front()});
if (insertForwardDeclaration(op, rewriter, allocName, funcType)
.failed()) {
return mlir::failure();
}
}
// Replace the operation with a call to the `funcName`
{
// Create the err value
auto err = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
// Add the call to the allocation
// TODO - 2018
auto lweSize = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(2048));
mlir::SmallVector<mlir::Value, 1> allocOperands{err, lweSize};
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, allocName, op.getType(), allocOperands);
// Add err and allocated value to operands
mlir::SmallVector<mlir::Value, 4> newOperands{err, alloc.getResult(0)};
for (auto operand : op->getOperands()) {
newOperands.push_back(operand);
}
rewriter.create<mlir::CallOp>(op.getLoc(), funcName, mlir::TypeRange{},
newOperands);
}
return mlir::success();
};
private:
std::string funcName;
std::string allocName;
};
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
/// operators to the corresponding function call to the `Concrete C API`.
void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
mlir::zamalang::LowLFHE::AddLweCiphertextsOp>>(
patterns.getContext(), "add_lwe_ciphertexts_u64",
"allocate_lwe_ciphertext_u64");
}
namespace {
struct LowLFHEToConcreteCAPIPass
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
void runOnOperation() final;
};
} // namespace
void LowLFHEToConcreteCAPIPass::runOnOperation() {
// Setup the conversion target.
mlir::ConversionTarget target(getContext());
target.addIllegalDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
mlir::memref::MemRefDialect>();
// Setup rewrite patterns
mlir::RewritePatternSet patterns(&getContext());
populateLowLFHEToConcreteCAPICall(patterns);
// Apply the conversion
mlir::ModuleOp op = getOperation();
if (mlir::applyFullConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
}
}
namespace mlir {
namespace zamalang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEToConcreteCAPIPass() {
return std::make_unique<LowLFHEToConcreteCAPIPass>();
}
} // namespace zamalang
} // namespace mlir

View File

@@ -5,6 +5,7 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
@@ -37,7 +38,9 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
// Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and
// add our types conversion to `llvm` compatible type.
mlir::LLVMTypeConverter typeConverter(&getContext());
mlir::LowerToLLVMOptions options(&getContext());
options.useBarePtrCallConv = true;
mlir::LLVMTypeConverter typeConverter(&getContext(), options);
typeConverter.addConversion(convertTypes);
// Setup the set of the patterns rewriter. At this point we want to
@@ -45,6 +48,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
mlir::RewritePatternSet patterns(&getContext());
mlir::populateLoopToStdConversionPatterns(patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
// Apply a `FullConversion` to `llvm`.
auto module = getOperation();
@@ -57,7 +61,7 @@ llvm::Optional<mlir::Type>
MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
return mlir::LLVM::LLVMPointerType::get(
mlir::IntegerType::get(type.getContext(), 8));
mlir::IntegerType::get(type.getContext(), 64));
}
return llvm::None;
}

View File

@@ -41,6 +41,8 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), enablePass);
constraint = defaultGlobalFHECircuitConstraint;
// Run the passes

View File

@@ -253,6 +253,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
context.getOrLoadDialect<mlir::StandardOpsDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
if (cmdline::verifyDiagnostics)
context.printOpOnDiagnostic(false);