diff --git a/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h b/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h new file mode 100644 index 000000000..891121627 --- /dev/null +++ b/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h @@ -0,0 +1,16 @@ + +#ifndef ZAMALANG_CONVERSION_LOWLFHETOCONCRETECAPI_PASS_H_ +#define ZAMALANG_CONVERSION_LOWLFHETOCONCRETECAPI_PASS_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace zamalang { +/// Create a pass to convert `LowLFHE` operators to function call to the +/// `ConcreteCAPI` +std::unique_ptr> +createConvertLowLFHEToConcreteCAPIPass(); +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/zamalang/Conversion/Passes.h b/compiler/include/zamalang/Conversion/Passes.h index e90ed5758..822ffd108 100644 --- a/compiler/include/zamalang/Conversion/Passes.h +++ b/compiler/include/zamalang/Conversion/Passes.h @@ -8,9 +8,11 @@ #include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h" #include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h" +#include "zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h" #include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h" #include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" #define GEN_PASS_CLASSES #include "zamalang/Conversion/Passes.h.inc" diff --git a/compiler/include/zamalang/Conversion/Passes.td b/compiler/include/zamalang/Conversion/Passes.td index f1f39f203..1781abcf6 100644 --- a/compiler/include/zamalang/Conversion/Passes.td +++ b/compiler/include/zamalang/Conversion/Passes.td @@ -25,6 +25,12 @@ def MidLFHEToLowLFHE : Pass<"midlfhe-to-lowlfhe", "mlir::ModuleOp"> { let dependentDialects = ["mlir::linalg::LinalgDialect"]; } +def LowLFHEToConcreteCAPI : Pass<"lowlfhe-to-concrete-c-api", "mlir::ModuleOp"> { + let summary = "Lower operations from the LowLFHE dialect to std with function call to the Concrete C API"; + let constructor = "mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass()"; + let dependentDialects = ["mlir::zamalang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"]; +} + def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> { let summary = "Lowers operations from MLIR lowerable dialects to LLVM"; let constructor = "mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()"; diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index 9de41dd79..070482fac 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(HLFHEToMidLFHE) add_subdirectory(MidLFHEToLowLFHE) add_subdirectory(HLFHETensorOpsToLinalg) +add_subdirectory(LowLFHEToConcreteCAPI) add_subdirectory(MLIRLowerableDialectsToLLVM) diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/CMakeLists.txt b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/CMakeLists.txt new file mode 100644 index 000000000..0ec258924 --- /dev/null +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(LowLFHEToConcreteCAPI + LowLFHEToConcreteCAPI.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE + + DEPENDS + LowLFHEDialect + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransforms +) + +target_link_libraries(LowLFHEToConcreteCAPI PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp new file mode 100644 index 000000000..ae2d07c56 --- /dev/null +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp @@ -0,0 +1,161 @@ +#include "mlir//IR/BuiltinTypes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "zamalang/Conversion/Passes.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h" + +/// LowLFHEOpToConcreteCAPICallPattern match the `Op` Operation and +/// replace with a call to `funcName`, the funcName should be an external +/// function that was linked later. It insert the forward declaration of the +/// private `funcName` if it not already in the symbol table. +/// The C signature of the function should be `void funcName(int *err, out, +/// arg0, arg1)`, the pattern rewrite: +/// ``` +/// out = op(arg0, arg1) +/// ``` +/// to +/// ``` +/// err = memref.alloc() : memref +/// out = _allocate_(err); +/// call_op(err, out, arg0, arg1); +/// ``` +template +struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern { + LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context, + mlir::StringRef funcName, + mlir::StringRef allocName, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit), funcName(funcName), + allocName(allocName) {} + + mlir::LogicalResult static insertForwardDeclaration( + Op op, mlir::PatternRewriter &rewriter, llvm::StringRef funcName, + mlir::FunctionType funcType) { + // Looking for the `funcName` Operation + auto module = mlir::SymbolTable::getNearestSymbolTable(op); + auto opFunc = mlir::dyn_cast_or_null( + mlir::SymbolTable::lookupSymbolIn(module, funcName)); + if (!opFunc) { + // Insert the forward declaration of the funcName + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&module->getRegion(0).front()); + + opFunc = rewriter.create(rewriter.getUnknownLoc(), funcName, + funcType); + opFunc.setPrivate(); + } else { + // Check if the `funcName` is well a private function + if (!opFunc.isPrivate()) { + op.emitError() << "the function \"" << funcName + << "\" conflicts with the concrete C API, please rename"; + return mlir::failure(); + } + } + assert(mlir::SymbolTable::lookupSymbolIn(module, funcName) + ->template hasTrait()); + return mlir::success(); + } + + mlir::LogicalResult + matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { + auto errType = + mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext())); + // Insert forward declaration of the operator function + { + mlir::SmallVector operands{errType, + op->getResultTypes().front()}; + for (auto ty : op->getOperandTypes()) { + operands.push_back(ty); + } + auto funcType = + mlir::FunctionType::get(rewriter.getContext(), operands, {}); + if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the alloc function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), {errType, rewriter.getIndexType()}, + {op->getResultTypes().front()}); + if (insertForwardDeclaration(op, rewriter, allocName, funcType) + .failed()) { + return mlir::failure(); + } + } + // Replace the operation with a call to the `funcName` + { + // Create the err value + auto err = rewriter.create(op.getLoc(), errType); + // Add the call to the allocation + // TODO - 2018 + auto lweSize = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(2048)); + mlir::SmallVector allocOperands{err, lweSize}; + auto alloc = rewriter.replaceOpWithNewOp( + op, allocName, op.getType(), allocOperands); + + // Add err and allocated value to operands + mlir::SmallVector newOperands{err, alloc.getResult(0)}; + for (auto operand : op->getOperands()) { + newOperands.push_back(operand); + } + rewriter.create(op.getLoc(), funcName, mlir::TypeRange{}, + newOperands); + } + return mlir::success(); + }; + +private: + std::string funcName; + std::string allocName; +}; + +/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE +/// operators to the corresponding function call to the `Concrete C API`. +void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) { + patterns.add>( + patterns.getContext(), "add_lwe_ciphertexts_u64", + "allocate_lwe_ciphertext_u64"); +} + +namespace { +struct LowLFHEToConcreteCAPIPass + : public LowLFHEToConcreteCAPIBase { + void runOnOperation() final; +}; +} // namespace + +void LowLFHEToConcreteCAPIPass::runOnOperation() { + // Setup the conversion target. + mlir::ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addLegalDialect(); + + // Setup rewrite patterns + mlir::RewritePatternSet patterns(&getContext()); + populateLowLFHEToConcreteCAPICall(patterns); + + // Apply the conversion + mlir::ModuleOp op = getOperation(); + if (mlir::applyFullConversion(op, target, std::move(patterns)).failed()) { + this->signalPassFailure(); + } +} + +namespace mlir { +namespace zamalang { +std::unique_ptr> +createConvertLowLFHEToConcreteCAPIPass() { + return std::make_unique(); +} +} // namespace zamalang +} // namespace mlir \ No newline at end of file diff --git a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp index 9ec6f2909..58c377478 100644 --- a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp +++ b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp @@ -5,6 +5,7 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" @@ -37,7 +38,9 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() { // Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and // add our types conversion to `llvm` compatible type. - mlir::LLVMTypeConverter typeConverter(&getContext()); + mlir::LowerToLLVMOptions options(&getContext()); + options.useBarePtrCallConv = true; + mlir::LLVMTypeConverter typeConverter(&getContext(), options); typeConverter.addConversion(convertTypes); // Setup the set of the patterns rewriter. At this point we want to @@ -45,6 +48,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); mlir::populateLoopToStdConversionPatterns(patterns); mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateMemRefToLLVMConversionPatterns(typeConverter, patterns); // Apply a `FullConversion` to `llvm`. auto module = getOperation(); @@ -57,7 +61,7 @@ llvm::Optional MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) { if (type.isa()) { return mlir::LLVM::LLVMPointerType::get( - mlir::IntegerType::get(type.getContext(), 8)); + mlir::IntegerType::get(type.getContext(), 64)); } return llvm::None; } diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index c03c7682a..c4780acb1 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -41,6 +41,8 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect( pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass); addFilteredPassToPassManager( pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass); + addFilteredPassToPassManager( + pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), enablePass); constraint = defaultGlobalFHECircuitConstraint; // Run the passes diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 9bfc47455..0c56fa967 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -253,6 +253,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); + context.getOrLoadDialect(); if (cmdline::verifyDiagnostics) context.printOpOnDiagnostic(false);