feat(compiler): Integrate HLFHELinalg tiling passes into compilation pipeline

This commit is contained in:
Andi Drebes
2021-12-15 16:56:16 +01:00
parent bc75831c86
commit f319ba37d2
10 changed files with 79 additions and 6 deletions

View File

@@ -148,12 +148,15 @@ public:
void setVerifyDiagnostics(bool v);
void setGenerateClientParameters(bool v);
void setClientParametersFuncName(const llvm::StringRef &name);
void setHLFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
protected:
llvm::Optional<size_t> overrideMaxEintPrecision;
llvm::Optional<size_t> overrideMaxMANP;
llvm::Optional<std::string> clientParametersFuncName;
llvm::Optional<std::vector<int64_t>> hlfhelinalgTileSizes;
bool verifyDiagnostics;
bool generateClientParameters;
std::function<bool(mlir::Pass *)> enablePass;

View File

@@ -16,6 +16,15 @@ llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
markHLFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::ArrayRef<int64_t> tileSizes,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);

View File

@@ -84,6 +84,10 @@ void HLFHEToMidLFHEPass::runOnOperation() {
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
HLFHEToMidLFHETypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
HLFHEToMidLFHETypeConverter>>(
&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);

View File

@@ -98,6 +98,9 @@ void LowLFHEUnparametrizePass::runOnOperation() {
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
LowLFHEUnparametrizeTypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
LowLFHEUnparametrizeTypeConverter>>(
&getContext(), converter);
// Conversion of function signature and arguments
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {

View File

@@ -294,6 +294,9 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() {
patterns.add<RegionOpTypeConverterPattern<
mlir::tensor::GenerateOp, MidLFHEGlobalParametrizationTypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<
mlir::scf::ForOp, MidLFHEGlobalParametrizationTypeConverter>>(
&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);

View File

@@ -58,14 +58,13 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
target.addIllegalDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
// Make sure that no ops `linalg.generic` that have illegal types
target
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
[&](mlir::Operation *op) {
return (
converter.isLegal(op->getOperandTypes()) &&
target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
mlir::tensor::GenerateOp, mlir::scf::ForOp>(
[&](mlir::Operation *op) {
return (converter.isLegal(op->getOperandTypes()) &&
converter.isLegal(op->getResultTypes()) &&
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
});
});
// Make sure that func has legal signature
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
@@ -83,6 +82,9 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
MidLFHEToLowLFHETypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
MidLFHEToLowLFHETypeConverter>>(
&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);

View File

@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -22,6 +22,7 @@ add_mlir_library(ZamalangSupport
LINK_LIBS PUBLIC
HLFHELinalgDialect
HLFHELinalgDialectTransforms
HLFHETensorOpsToLinalg
HLFHEToMidLFHE
LowLFHEUnparametrize

View File

@@ -5,6 +5,7 @@
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/SCF/SCF.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
@@ -53,6 +54,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
this->mlirContext->getOrLoadDialect<mlir::memref::MemRefDialect>();
this->mlirContext->getOrLoadDialect<mlir::linalg::LinalgDialect>();
this->mlirContext->getOrLoadDialect<mlir::LLVM::LLVMDialect>();
this->mlirContext->getOrLoadDialect<mlir::scf::SCFDialect>();
}
return this->mlirContext;
@@ -94,6 +96,10 @@ void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) {
this->clientParametersFuncName = name.str();
}
void CompilerEngine::setHLFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes) {
this->hlfhelinalgTileSizes = sizes.vec();
}
void CompilerEngine::setEnablePass(
std::function<bool(mlir::Pass *)> enablePass) {
this->enablePass = enablePass;
@@ -192,6 +198,21 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
// HLFHE High level pass to determine FHE parameters
if (auto err = this->determineFHEParameters(res))
return std::move(err);
// HLFHELinalg tiling
if (this->hlfhelinalgTileSizes) {
if (mlir::zamalang::pipeline::markHLFHELinalgForTiling(
mlirContext, module, *this->hlfhelinalgTileSizes, enablePass)
.failed())
return errorDiag("Marking of HLFHELinalg operations for tiling failed");
}
if (mlir::zamalang::pipeline::tileMarkedHLFHELinalg(mlirContext, module,
enablePass)
.failed()) {
return errorDiag("Tiling of HLFHELinalg operations failed");
}
if (target == Target::HLFHE)
return std::move(res);

View File

@@ -3,6 +3,7 @@
#include <llvm/Support/Error.h>
#include <mlir/Conversion/SCFToStandard/SCFToStandard.h>
#include <mlir/Dialect/Linalg/Passes.h>
#include <mlir/Dialect/SCF/Passes.h>
#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
#include <mlir/Dialect/Tensor/Transforms/Passes.h>
#include <mlir/ExecutionEngine/OptUtils.h>
@@ -13,6 +14,7 @@
#include <zamalang/Conversion/Passes.h>
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
#include <zamalang/Dialect/HLFHELinalg/Transforms/Tiling.h>
#include <zamalang/Support/Pipeline.h>
#include <zamalang/Support/logging.h>
#include <zamalang/Support/math.h>
@@ -100,6 +102,29 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
return ret;
}
mlir::LogicalResult
tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("TileMarkedHLFHELinalg", pm, context);
addPotentiallyNestedPass(pm, mlir::zamalang::createHLFHELinalgTilingPass(),
enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
markHLFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::ArrayRef<int64_t> tileSizes,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("MarkHLFHELinalgForTiling", pm, context);
addPotentiallyNestedPass(pm, createHLFHELinalgTilingMarkerPass(tileSizes),
enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
@@ -163,6 +188,7 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass(), enablePass);
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(),
enablePass);
addPotentiallyNestedPass(pm, mlir::createSCFBufferizePass(), enablePass);
addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass);
addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(),
enablePass);