// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/Sequence.h" #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/Tools.h" #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" #include "concretelang/Dialect/RT/Analysis/Autopar.h" #include "concretelang/Dialect/RT/IR/RTTypes.h" namespace { struct MLIRLowerableDialectsToLLVMPass : public MLIRLowerableDialectsToLLVMBase { void runOnOperation() final; /// Convert types to the LLVM dialect-compatible type static llvm::Optional convertTypes(mlir::Type type); }; } // namespace /// This rewrite pattern transforms any instance of `memref.copy` /// operators on 1D memref. /// This is introduced to avoid the MLIR lowering of `memref.copy` of ranked /// memref that basically allocate unranked memref structure on the stack before /// calling @memrefCopy. /// /// Example: /// /// ```mlir /// memref.copy %src, %dst : memref to memref /// ``` /// /// becomes: /// /// ```mlir /// %_src = memref.cast %src = memref to memref /// %_dst = memref.cast %dst = memref to memref /// call @memref_copy_one_rank(%_src, %_dst) : (tensor, tensor) -> /// () /// ``` struct Memref1DCopyOpPattern : public mlir::OpRewritePattern { Memref1DCopyOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : mlir::OpRewritePattern(context, benefit) {} mlir::LogicalResult matchAndRewrite(mlir::memref::CopyOp copyOp, mlir::PatternRewriter &rewriter) const override { if (copyOp.source().getType().cast().getRank() != 1 || copyOp.source().getType().cast().getRank() != 1) { return mlir::failure(); } auto opType = mlir::MemRefType::get({-1}, rewriter.getI64Type()); // Insert forward declaration of the add_lwe_ciphertexts function { if (insertForwardDeclaration( copyOp, rewriter, "memref_copy_one_rank", mlir::FunctionType::get(rewriter.getContext(), {opType, opType}, {})) .failed()) { return mlir::failure(); } } auto sourceOp = rewriter.create( copyOp.getLoc(), opType, copyOp.source()); auto targetOp = rewriter.create( copyOp.getLoc(), opType, copyOp.target()); rewriter.replaceOpWithNewOp( copyOp, "memref_copy_one_rank", mlir::TypeRange{}, mlir::ValueRange{sourceOp, targetOp}); return mlir::success(); }; }; void MLIRLowerableDialectsToLLVMPass::runOnOperation() { // Setup the conversion target. We reuse the LLVMConversionTarget that // legalize LLVM dialect. mlir::LLVMConversionTarget target(getContext()); target.addLegalOp(); target.addIllegalOp(); // Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and // add our types conversion to `llvm` compatible type. mlir::LowerToLLVMOptions options(&getContext()); mlir::LLVMTypeConverter typeConverter(&getContext(), options); typeConverter.addConversion(convertTypes); typeConverter.addConversion( [&](mlir::concretelang::Concrete::PlaintextType type) { return mlir::IntegerType::get(type.getContext(), 64); }); typeConverter.addConversion( [&](mlir::concretelang::Concrete::CleartextType type) { return mlir::IntegerType::get(type.getContext(), 64); }); // Setup the set of the patterns rewriter. At this point we want to // convert the `scf` operations to `std` and `std` operations to `llvm`. mlir::RewritePatternSet patterns(&getContext()); patterns.add(&getContext(), 100); mlir::concretelang::populateRTToLLVMConversionPatterns(typeConverter, patterns); mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns); mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); mlir::populateMemRefToLLVMConversionPatterns(typeConverter, patterns); mlir::populateSCFToControlFlowConversionPatterns(patterns); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); target.addLegalOp(); mlir::populateOpenMPToLLVMConversionPatterns(typeConverter, patterns); target.addDynamicallyLegalOp([&](mlir::Operation *op) { return typeConverter.isLegal(&op->getRegion(0)); }); target.addLegalOp(); // Apply a `FullConversion` to `llvm`. auto module = getOperation(); if (mlir::applyFullConversion(module, target, std::move(patterns)).failed()) { signalPassFailure(); } } llvm::Optional MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) { if (type.isa() || type.isa() || type.isa() || type.isa()) { return mlir::LLVM::LLVMPointerType::get( mlir::IntegerType::get(type.getContext(), 64)); } if (type.isa()) { mlir::LowerToLLVMOptions options(type.getContext()); mlir::LLVMTypeConverter typeConverter(type.getContext(), options); typeConverter.addConversion(convertTypes); typeConverter.addConversion( [&](mlir::concretelang::Concrete::PlaintextType type) { return mlir::IntegerType::get(type.getContext(), 64); }); typeConverter.addConversion( [&](mlir::concretelang::Concrete::CleartextType type) { return mlir::IntegerType::get(type.getContext(), 64); }); mlir::Type subtype = type.dyn_cast().getElementType(); mlir::Type convertedSubtype = typeConverter.convertType(subtype); return mlir::LLVM::LLVMPointerType::get(convertedSubtype); } return llvm::None; } namespace mlir { namespace concretelang { /// Create a pass for lowering operations the remaining mlir dialects /// operations, to the LLVM dialect for codegen. std::unique_ptr> createConvertMLIRLowerableDialectsToLLVMPass() { return std::make_unique(); } } // namespace concretelang } // namespace mlir