fix(compiler): Add a pass to unparametrize LowLFHE ciphertext to remove unrelized_convesrion_cast for linalg bufferization

This commit is contained in:
Quentin Bourgerie
2021-09-09 20:13:02 +02:00
parent d0e71dd4f1
commit 8cc0af1220
8 changed files with 174 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
#ifndef ZAMALANG_CONVERSION_LOWLFHEUNPARAMETRIZE_PASS_H_
#define ZAMALANG_CONVERSION_LOWLFHEUNPARAMETRIZE_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace zamalang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEUnparametrizePass();
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -9,6 +9,7 @@
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
#include "zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h"
#include "zamalang/Conversion/LowLFHEUnparametrize/Pass.h"
#include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
#include "zamalang/Conversion/MidLFHEGlobalParametrization/Pass.h"
#include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h"

View File

@@ -38,6 +38,12 @@ def LowLFHEToConcreteCAPI : Pass<"lowlfhe-to-concrete-c-api", "mlir::ModuleOp">
let dependentDialects = ["mlir::zamalang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
def LowLFHEUnparametrize : Pass<"lowlfhe-unparametrize", "mlir::ModuleOp"> {
let summary = "Unparametrize LowLFHE types and remove unrealized_conversion_cast";
let constructor = "mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass()";
let dependentDialects = ["mlir::zamalang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
let constructor = "mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()";

View File

@@ -4,3 +4,4 @@ add_subdirectory(MidLFHEToLowLFHE)
add_subdirectory(HLFHETensorOpsToLinalg)
add_subdirectory(LowLFHEToConcreteCAPI)
add_subdirectory(MLIRLowerableDialectsToLLVM)
add_subdirectory(LowLFHEUnparametrize)

View File

@@ -0,0 +1,16 @@
add_mlir_dialect_library(LowLFHEUnparametrize
LowLFHEUnparametrize.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE
DEPENDS
LowLFHEDialect
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
)
target_link_libraries(LowLFHEUnparametrize PUBLIC MLIRIR)

View File

@@ -0,0 +1,129 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h"
#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
/// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize
/// LowLFHE types
class LowLFHEUnparametrizeTypeConverter : public mlir::TypeConverter {
public:
static mlir::Type unparematrizeLowLFHEType(mlir::Type type) {
if (type.isa<mlir::zamalang::LowLFHE::PlaintextType>()) {
return mlir::IntegerType::get(type.getContext(), 64);
}
if (type.isa<mlir::zamalang::LowLFHE::CleartextType>()) {
return mlir::IntegerType::get(type.getContext(), 64);
}
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
return mlir::zamalang::LowLFHE::LweCiphertextType::get(type.getContext(),
-1, -1);
}
auto tensorType = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensorType != nullptr) {
auto eltTy0 = tensorType.getElementType();
auto eltTy1 = unparematrizeLowLFHEType(eltTy0);
if (eltTy0 == eltTy1) {
return type;
}
return mlir::RankedTensorType::get(tensorType.getShape(), eltTy1);
}
return type;
}
LowLFHEUnparametrizeTypeConverter() {
addConversion(
[](mlir::Type type) { return unparematrizeLowLFHEType(type); });
}
};
/// Replace `%1 = unrealized_conversion_cast %0 : t0 to t1` to `%0` where t0 or
/// t1 are a LowLFHE type.
struct LowLFHEUnrealizedCastReplacementPattern
: public mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp> {
LowLFHEUnrealizedCastReplacementPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp>(context,
benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::UnrealizedConversionCastOp op,
mlir::PatternRewriter &rewriter) const override {
if (mlir::isa<mlir::zamalang::LowLFHE::LowLFHEDialect>(
op.getOperandTypes()[0].getDialect()) ||
mlir::isa<mlir::zamalang::LowLFHE::LowLFHEDialect>(
op.getType(0).getDialect())) {
rewriter.replaceOp(op, op.getOperands());
return mlir::success();
}
return mlir::failure();
};
};
/// LowLFHEUnparametrizePass remove all parameters of LowLFHE types and remove
/// the unrealized_conversion_cast operation that operates on parametrized
/// LowLFHE types.
struct LowLFHEUnparametrizePass
: public LowLFHEUnparametrizeBase<LowLFHEUnparametrizePass> {
void runOnOperation() final;
};
void LowLFHEUnparametrizePass::runOnOperation() {
auto op = this->getOperation();
mlir::ConversionTarget target(getContext());
mlir::OwningRewritePatternList patterns(&getContext());
LowLFHEUnparametrizeTypeConverter converter;
// Conversion of linalg.generic operation
target.addDynamicallyLegalOp<mlir::linalg::GenericOp>(
[&](mlir::linalg::GenericOp op) {
return (converter.isLegal(op.getOperandTypes()) &&
converter.isLegal(op.getResultTypes()) &&
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
});
patterns.add<
LinalgGenericTypeConverterPattern<LowLFHEUnparametrizeTypeConverter>>(
&getContext(), converter);
// Conversion of function signature and arguments
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getType()) &&
converter.isLegal(&funcOp.getBody());
});
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// Replacement of unrealized_conversion_cast
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::UnrealizedConversionCastOp>(
target, converter);
patterns.add<LowLFHEUnrealizedCastReplacementPattern>(patterns.getContext());
// Conversion of tensor operators
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
// Conversion of CallOp
patterns.add<mlir::zamalang::GenericTypeConverterPattern<mlir::CallOp>>(
patterns.getContext(), converter);
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::CallOp>(target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
}
}
namespace mlir {
namespace zamalang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEUnparametrizePass() {
return std::make_unique<LowLFHEUnparametrizePass>();
}
} // namespace zamalang
} // namespace mlir

View File

@@ -14,6 +14,7 @@ add_mlir_library(ZamalangSupport
LINK_LIBS PUBLIC
HLFHETensorOpsToLinalg
HLFHEToMidLFHE
LowLFHEUnparametrize
MLIRLowerableDialectsToLLVM
MLIRExecutionEngine

View File

@@ -89,6 +89,11 @@ mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
pm.enableVerifier();
}
// Unparametrize LowLFHE
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass(),
options.enablePass);
// Bufferize
addFilteredPassToPassManager(pm, mlir::createTensorConstantBufferizePass(),
options.enablePass);
@@ -105,6 +110,7 @@ mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
addFilteredPassToPassManager(pm, mlir::createFinalizingBufferizePass(),
options.enablePass);
// Convert to MLIR LLVM Dialect
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
options.enablePass);