// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. // See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information. #include #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/Sequence.h" #include "zamalang/Conversion/Passes.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" #include "zamalang/Dialect/RT/Analysis/Autopar.h" #include "zamalang/Dialect/RT/IR/RTTypes.h" namespace { struct MLIRLowerableDialectsToLLVMPass : public MLIRLowerableDialectsToLLVMBase { void runOnOperation() final; /// Convert types to the LLVM dialect-compatible type static llvm::Optional convertTypes(mlir::Type type); }; } // namespace void MLIRLowerableDialectsToLLVMPass::runOnOperation() { // Setup the conversion target. We reuse the LLVMConversionTarget that // legalize LLVM dialect. mlir::LLVMConversionTarget target(getContext()); target.addLegalOp(); target.addIllegalOp(); // Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and // add our types conversion to `llvm` compatible type. mlir::LowerToLLVMOptions options(&getContext()); mlir::LLVMTypeConverter typeConverter(&getContext(), options); typeConverter.addConversion(convertTypes); typeConverter.addConversion([&](mlir::zamalang::LowLFHE::PlaintextType type) { return mlir::IntegerType::get(type.getContext(), 64); }); typeConverter.addConversion([&](mlir::zamalang::LowLFHE::CleartextType type) { return mlir::IntegerType::get(type.getContext(), 64); }); // Setup the set of the patterns rewriter. At this point we want to // convert the `scf` operations to `std` and `std` operations to `llvm`. mlir::RewritePatternSet patterns(&getContext()); mlir::zamalang::populateRTToLLVMConversionPatterns(typeConverter, patterns); mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); mlir::populateMemRefToLLVMConversionPatterns(typeConverter, patterns); // Apply a `FullConversion` to `llvm`. auto module = getOperation(); if (mlir::applyFullConversion(module, target, std::move(patterns)).failed()) { signalPassFailure(); } } llvm::Optional MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) { if (type.isa() || type.isa() || type.isa() || type.isa() || type.isa() || type.isa() || type.isa() || type.isa()) { return mlir::LLVM::LLVMPointerType::get( mlir::IntegerType::get(type.getContext(), 64)); } if (type.isa()) { mlir::LowerToLLVMOptions options(type.getContext()); mlir::LLVMTypeConverter typeConverter(type.getContext(), options); typeConverter.addConversion(convertTypes); typeConverter.addConversion( [&](mlir::zamalang::LowLFHE::PlaintextType type) { return mlir::IntegerType::get(type.getContext(), 64); }); typeConverter.addConversion( [&](mlir::zamalang::LowLFHE::CleartextType type) { return mlir::IntegerType::get(type.getContext(), 64); }); mlir::Type subtype = type.dyn_cast().getElementType(); mlir::Type convertedSubtype = typeConverter.convertType(subtype); return mlir::LLVM::LLVMPointerType::get(convertedSubtype); } return llvm::None; } namespace mlir { namespace zamalang { /// Create a pass for lowering operations the remaining mlir dialects /// operations, to the LLVM dialect for codegen. std::unique_ptr> createConvertMLIRLowerableDialectsToLLVMPass() { return std::make_unique(); } } // namespace zamalang } // namespace mlir