From 8cc0af1220b744489e6f93317225a1aa262774ff Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Thu, 9 Sep 2021 20:13:02 +0200 Subject: [PATCH] fix(compiler): Add a pass to unparametrize LowLFHE ciphertext to remove unrelized_convesrion_cast for linalg bufferization --- .../Conversion/LowLFHEUnparametrize/Pass.h | 14 ++ compiler/include/zamalang/Conversion/Passes.h | 1 + .../include/zamalang/Conversion/Passes.td | 6 + compiler/lib/Conversion/CMakeLists.txt | 1 + .../LowLFHEUnparametrize/CMakeLists.txt | 16 +++ .../LowLFHEUnparametrize.cpp | 129 ++++++++++++++++++ compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/CompilerTools.cpp | 6 + 8 files changed, 174 insertions(+) create mode 100644 compiler/include/zamalang/Conversion/LowLFHEUnparametrize/Pass.h create mode 100644 compiler/lib/Conversion/LowLFHEUnparametrize/CMakeLists.txt create mode 100644 compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp diff --git a/compiler/include/zamalang/Conversion/LowLFHEUnparametrize/Pass.h b/compiler/include/zamalang/Conversion/LowLFHEUnparametrize/Pass.h new file mode 100644 index 000000000..632036689 --- /dev/null +++ b/compiler/include/zamalang/Conversion/LowLFHEUnparametrize/Pass.h @@ -0,0 +1,14 @@ + +#ifndef ZAMALANG_CONVERSION_LOWLFHEUNPARAMETRIZE_PASS_H_ +#define ZAMALANG_CONVERSION_LOWLFHEUNPARAMETRIZE_PASS_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace zamalang { +std::unique_ptr> +createConvertLowLFHEUnparametrizePass(); +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/zamalang/Conversion/Passes.h b/compiler/include/zamalang/Conversion/Passes.h index 524fc8985..644f59a2e 100644 --- a/compiler/include/zamalang/Conversion/Passes.h +++ b/compiler/include/zamalang/Conversion/Passes.h @@ -9,6 +9,7 @@ #include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h" #include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h" #include "zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h" +#include "zamalang/Conversion/LowLFHEUnparametrize/Pass.h" #include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h" #include "zamalang/Conversion/MidLFHEGlobalParametrization/Pass.h" #include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h" diff --git a/compiler/include/zamalang/Conversion/Passes.td b/compiler/include/zamalang/Conversion/Passes.td index f4c7f6cd5..87e8092f9 100644 --- a/compiler/include/zamalang/Conversion/Passes.td +++ b/compiler/include/zamalang/Conversion/Passes.td @@ -38,6 +38,12 @@ def LowLFHEToConcreteCAPI : Pass<"lowlfhe-to-concrete-c-api", "mlir::ModuleOp"> let dependentDialects = ["mlir::zamalang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"]; } +def LowLFHEUnparametrize : Pass<"lowlfhe-unparametrize", "mlir::ModuleOp"> { + let summary = "Unparametrize LowLFHE types and remove unrealized_conversion_cast"; + let constructor = "mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass()"; + let dependentDialects = ["mlir::zamalang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"]; +} + def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> { let summary = "Lowers operations from MLIR lowerable dialects to LLVM"; let constructor = "mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()"; diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index c6646bdc9..04d077129 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(MidLFHEToLowLFHE) add_subdirectory(HLFHETensorOpsToLinalg) add_subdirectory(LowLFHEToConcreteCAPI) add_subdirectory(MLIRLowerableDialectsToLLVM) +add_subdirectory(LowLFHEUnparametrize) diff --git a/compiler/lib/Conversion/LowLFHEUnparametrize/CMakeLists.txt b/compiler/lib/Conversion/LowLFHEUnparametrize/CMakeLists.txt new file mode 100644 index 000000000..6f86ec7d5 --- /dev/null +++ b/compiler/lib/Conversion/LowLFHEUnparametrize/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(LowLFHEUnparametrize + LowLFHEUnparametrize.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE + + DEPENDS + LowLFHEDialect + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransforms +) + +target_link_libraries(LowLFHEUnparametrize PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp b/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp new file mode 100644 index 000000000..b10a236ae --- /dev/null +++ b/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp @@ -0,0 +1,129 @@ +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "zamalang/Conversion/Passes.h" +#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h" +#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" + +/// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize +/// LowLFHE types +class LowLFHEUnparametrizeTypeConverter : public mlir::TypeConverter { + +public: + static mlir::Type unparematrizeLowLFHEType(mlir::Type type) { + if (type.isa()) { + return mlir::IntegerType::get(type.getContext(), 64); + } + if (type.isa()) { + return mlir::IntegerType::get(type.getContext(), 64); + } + if (type.isa()) { + return mlir::zamalang::LowLFHE::LweCiphertextType::get(type.getContext(), + -1, -1); + } + auto tensorType = type.dyn_cast_or_null(); + if (tensorType != nullptr) { + auto eltTy0 = tensorType.getElementType(); + auto eltTy1 = unparematrizeLowLFHEType(eltTy0); + if (eltTy0 == eltTy1) { + return type; + } + return mlir::RankedTensorType::get(tensorType.getShape(), eltTy1); + } + return type; + } + + LowLFHEUnparametrizeTypeConverter() { + addConversion( + [](mlir::Type type) { return unparematrizeLowLFHEType(type); }); + } +}; + +/// Replace `%1 = unrealized_conversion_cast %0 : t0 to t1` to `%0` where t0 or +/// t1 are a LowLFHE type. +struct LowLFHEUnrealizedCastReplacementPattern + : public mlir::OpRewritePattern { + LowLFHEUnrealizedCastReplacementPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, + benefit) {} + + mlir::LogicalResult + matchAndRewrite(mlir::UnrealizedConversionCastOp op, + mlir::PatternRewriter &rewriter) const override { + if (mlir::isa( + op.getOperandTypes()[0].getDialect()) || + mlir::isa( + op.getType(0).getDialect())) { + rewriter.replaceOp(op, op.getOperands()); + return mlir::success(); + } + return mlir::failure(); + }; +}; + +/// LowLFHEUnparametrizePass remove all parameters of LowLFHE types and remove +/// the unrealized_conversion_cast operation that operates on parametrized +/// LowLFHE types. +struct LowLFHEUnparametrizePass + : public LowLFHEUnparametrizeBase { + void runOnOperation() final; +}; + +void LowLFHEUnparametrizePass::runOnOperation() { + auto op = this->getOperation(); + + mlir::ConversionTarget target(getContext()); + mlir::OwningRewritePatternList patterns(&getContext()); + + LowLFHEUnparametrizeTypeConverter converter; + + // Conversion of linalg.generic operation + target.addDynamicallyLegalOp( + [&](mlir::linalg::GenericOp op) { + return (converter.isLegal(op.getOperandTypes()) && + converter.isLegal(op.getResultTypes()) && + converter.isLegal(op->getRegion(0).front().getArgumentTypes())); + }); + patterns.add< + LinalgGenericTypeConverterPattern>( + &getContext(), converter); + + // Conversion of function signature and arguments + target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { + return converter.isSignatureLegal(funcOp.getType()) && + converter.isLegal(&funcOp.getBody()); + }); + mlir::populateFuncOpTypeConversionPattern(patterns, converter); + + // Replacement of unrealized_conversion_cast + mlir::zamalang::addDynamicallyLegalTypeOp( + target, converter); + patterns.add(patterns.getContext()); + + // Conversion of tensor operators + mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, + converter); + + // Conversion of CallOp + patterns.add>( + patterns.getContext(), converter); + mlir::zamalang::addDynamicallyLegalTypeOp(target, converter); + + // Apply conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { + this->signalPassFailure(); + } +} + +namespace mlir { +namespace zamalang { +std::unique_ptr> +createConvertLowLFHEUnparametrizePass() { + return std::make_unique(); +} +} // namespace zamalang +} // namespace mlir diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 1f69d329b..0a611ab20 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_library(ZamalangSupport LINK_LIBS PUBLIC HLFHETensorOpsToLinalg HLFHEToMidLFHE + LowLFHEUnparametrize MLIRLowerableDialectsToLLVM MLIRExecutionEngine diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index 6da29ac63..030812cca 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -89,6 +89,11 @@ mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( pm.enableVerifier(); } + // Unparametrize LowLFHE + addFilteredPassToPassManager( + pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass(), + options.enablePass); + // Bufferize addFilteredPassToPassManager(pm, mlir::createTensorConstantBufferizePass(), options.enablePass); @@ -105,6 +110,7 @@ mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( addFilteredPassToPassManager(pm, mlir::createFinalizingBufferizePass(), options.enablePass); + // Convert to MLIR LLVM Dialect addFilteredPassToPassManager( pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(), options.enablePass);