Files
concrete/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp
2021-12-29 15:13:34 +01:00

116 lines
4.8 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#include <iostream>
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/RT/Analysis/Autopar.h"
#include "zamalang/Dialect/RT/IR/RTTypes.h"
namespace {
/// Module pass that lowers the operations of the remaining MLIR dialects
/// (RT, std, arith and memref patterns are populated in runOnOperation) and
/// the Zamalang types to the LLVM dialect.
class MLIRLowerableDialectsToLLVMPass
    : public MLIRLowerableDialectsToLLVMBase<MLIRLowerableDialectsToLLVMPass> {
public:
  void runOnOperation() final;

  /// Type-conversion callback mapping Zamalang-specific types to
  /// LLVM-dialect-compatible types. Returns llvm::None for any type it does
  /// not handle, so other registered conversions may be tried.
  static llvm::Optional<mlir::Type> convertTypes(mlir::Type type);
};
} // namespace
void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
// Setup the conversion target. We reuse the LLVMConversionTarget that
// legalize LLVM dialect.
mlir::LLVMConversionTarget target(getContext());
target.addLegalOp<mlir::ModuleOp>();
target.addIllegalOp<mlir::UnrealizedConversionCastOp>();
// Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and
// add our types conversion to `llvm` compatible type.
mlir::LowerToLLVMOptions options(&getContext());
mlir::LLVMTypeConverter typeConverter(&getContext(), options);
typeConverter.addConversion(convertTypes);
typeConverter.addConversion([&](mlir::zamalang::LowLFHE::PlaintextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
typeConverter.addConversion([&](mlir::zamalang::LowLFHE::CleartextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
// Setup the set of the patterns rewriter. At this point we want to
// convert the `scf` operations to `std` and `std` operations to `llvm`.
mlir::RewritePatternSet patterns(&getContext());
mlir::zamalang::populateRTToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
// Apply a `FullConversion` to `llvm`.
auto module = getOperation();
if (mlir::applyFullConversion(module, target, std::move(patterns)).failed()) {
signalPassFailure();
}
}
llvm::Optional<mlir::Type>
MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>() ||
type.isa<mlir::zamalang::LowLFHE::GlweCiphertextType>() ||
type.isa<mlir::zamalang::LowLFHE::LweKeySwitchKeyType>() ||
type.isa<mlir::zamalang::LowLFHE::LweBootstrapKeyType>() ||
type.isa<mlir::zamalang::LowLFHE::ContextType>() ||
type.isa<mlir::zamalang::LowLFHE::ForeignPlaintextListType>() ||
type.isa<mlir::zamalang::LowLFHE::PlaintextListType>() ||
type.isa<mlir::zamalang::RT::FutureType>()) {
return mlir::LLVM::LLVMPointerType::get(
mlir::IntegerType::get(type.getContext(), 64));
}
if (type.isa<mlir::zamalang::RT::PointerType>()) {
mlir::LowerToLLVMOptions options(type.getContext());
mlir::LLVMTypeConverter typeConverter(type.getContext(), options);
typeConverter.addConversion(convertTypes);
typeConverter.addConversion(
[&](mlir::zamalang::LowLFHE::PlaintextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
typeConverter.addConversion(
[&](mlir::zamalang::LowLFHE::CleartextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
mlir::Type subtype =
type.dyn_cast<mlir::zamalang::RT::PointerType>().getElementType();
mlir::Type convertedSubtype = typeConverter.convertType(subtype);
return mlir::LLVM::LLVMPointerType::get(convertedSubtype);
}
return llvm::None;
}
namespace mlir {
namespace zamalang {
/// Builds the pass that lowers the operations of the remaining MLIR dialects
/// to the LLVM dialect, ready for code generation.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertMLIRLowerableDialectsToLLVMPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<MLIRLowerableDialectsToLLVMPass>();
  return pass;
}
} // namespace zamalang
} // namespace mlir