From f319ba37d249a4c08ad365fe56f5fd118c0acac7 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 15 Dec 2021 16:56:16 +0100 Subject: [PATCH] feat(compiler): Integrate HLFHELinalg tiling passes into compilation pipeline --- .../include/zamalang/Support/CompilerEngine.h | 3 +++ compiler/include/zamalang/Support/Pipeline.h | 9 +++++++ .../HLFHEToMidLFHE/HLFHEToMidLFHE.cpp | 4 +++ .../LowLFHEUnparametrize.cpp | 3 +++ .../MidLFHEGlobalParametrization.cpp | 3 +++ .../MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp | 14 +++++----- .../lib/Dialect/HLFHELinalg/CMakeLists.txt | 1 + compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/CompilerEngine.cpp | 21 +++++++++++++++ compiler/lib/Support/Pipeline.cpp | 26 +++++++++++++++++++ 10 files changed, 79 insertions(+), 6 deletions(-) diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index 62ecd4a27..354107b91 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -148,12 +148,15 @@ public: void setVerifyDiagnostics(bool v); void setGenerateClientParameters(bool v); void setClientParametersFuncName(const llvm::StringRef &name); + void setHLFHELinalgTileSizes(llvm::ArrayRef sizes); void setEnablePass(std::function enablePass); protected: llvm::Optional overrideMaxEintPrecision; llvm::Optional overrideMaxMANP; llvm::Optional clientParametersFuncName; + llvm::Optional> hlfhelinalgTileSizes; + bool verifyDiagnostics; bool generateClientParameters; std::function enablePass; diff --git a/compiler/include/zamalang/Support/Pipeline.h b/compiler/include/zamalang/Support/Pipeline.h index 0a94e4f70..3bab5ee85 100644 --- a/compiler/include/zamalang/Support/Pipeline.h +++ b/compiler/include/zamalang/Support/Pipeline.h @@ -16,6 +16,15 @@ llvm::Expected> getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); +mlir::LogicalResult +tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); + +mlir::LogicalResult +markHLFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module, + llvm::ArrayRef tileSizes, + std::function enablePass); + mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); diff --git a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp index 355e39cb4..d75d1067a 100644 --- a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp +++ b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp @@ -84,6 +84,10 @@ void HLFHEToMidLFHEPass::runOnOperation() { patterns.add>( &getContext(), converter); + patterns.add>( + &getContext(), converter); + mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); diff --git a/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp b/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp index cb8d63899..32478330c 100644 --- a/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp +++ b/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp @@ -98,6 +98,9 @@ void LowLFHEUnparametrizePass::runOnOperation() { patterns.add>( &getContext(), converter); + patterns.add>( + &getContext(), converter); // Conversion of function signature and arguments target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index c203c67cc..ec6aabbcc 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -294,6 +294,9 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() { patterns.add>( &getContext(), converter); + patterns.add>( + &getContext(), converter); mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, converter); diff --git a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp index 3d6cef52f..934bf0cf5 100644 --- a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp +++ b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp @@ -58,14 +58,13 @@ void MidLFHEToLowLFHEPass::runOnOperation() { target.addIllegalDialect(); // Make sure that no ops `linalg.generic` that have illegal types - target - .addDynamicallyLegalOp( - [&](mlir::Operation *op) { - return ( - converter.isLegal(op->getOperandTypes()) && + target.addDynamicallyLegalOp( + [&](mlir::Operation *op) { + return (converter.isLegal(op->getOperandTypes()) && converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getRegion(0).front().getArgumentTypes())); - }); + }); // Make sure that func has legal signature target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { @@ -83,6 +82,9 @@ void MidLFHEToLowLFHEPass::runOnOperation() { patterns.add>( &getContext(), converter); + patterns.add>( + &getContext(), converter); mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); diff --git a/compiler/lib/Dialect/HLFHELinalg/CMakeLists.txt b/compiler/lib/Dialect/HLFHELinalg/CMakeLists.txt index f33061b2d..9f57627c3 100644 --- a/compiler/lib/Dialect/HLFHELinalg/CMakeLists.txt +++ b/compiler/lib/Dialect/HLFHELinalg/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 51632a41c..098b83b56 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -22,6 +22,7 @@ add_mlir_library(ZamalangSupport LINK_LIBS PUBLIC HLFHELinalgDialect + HLFHELinalgDialectTransforms HLFHETensorOpsToLinalg HLFHEToMidLFHE LowLFHEUnparametrize diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index f1a7358ab..d5869ca48 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -53,6 +54,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() { this->mlirContext->getOrLoadDialect(); this->mlirContext->getOrLoadDialect(); this->mlirContext->getOrLoadDialect(); + this->mlirContext->getOrLoadDialect(); } return this->mlirContext; @@ -94,6 +96,10 @@ void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) { this->clientParametersFuncName = name.str(); } +void CompilerEngine::setHLFHELinalgTileSizes(llvm::ArrayRef sizes) { + this->hlfhelinalgTileSizes = sizes.vec(); +} + void CompilerEngine::setEnablePass( std::function enablePass) { this->enablePass = enablePass; @@ -192,6 +198,21 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { // HLFHE High level pass to determine FHE parameters if (auto err = this->determineFHEParameters(res)) return std::move(err); + + // HLFHELinalg tiling + if (this->hlfhelinalgTileSizes) { + if (mlir::zamalang::pipeline::markHLFHELinalgForTiling( + mlirContext, module, *this->hlfhelinalgTileSizes, enablePass) + .failed()) + return errorDiag("Marking of HLFHELinalg operations for tiling failed"); + } + + if (mlir::zamalang::pipeline::tileMarkedHLFHELinalg(mlirContext, module, + enablePass) + .failed()) { + return errorDiag("Tiling of HLFHELinalg operations failed"); + } + if (target == Target::HLFHE) return std::move(res); diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 295096a3d..4a7f774df 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -13,6 +14,7 @@ #include #include +#include #include #include #include @@ -100,6 +102,29 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, return ret; } +mlir::LogicalResult +tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("TileMarkedHLFHELinalg", pm, context); + addPotentiallyNestedPass(pm, mlir::zamalang::createHLFHELinalgTilingPass(), + enablePass); + + return pm.run(module.getOperation()); +} + +mlir::LogicalResult +markHLFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module, + llvm::ArrayRef tileSizes, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("MarkHLFHELinalgForTiling", pm, context); + addPotentiallyNestedPass(pm, createHLFHELinalgTilingMarkerPass(tileSizes), + enablePass); + + return pm.run(module.getOperation()); +} + mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass) { @@ -163,6 +188,7 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass(), enablePass); addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(), enablePass); + addPotentiallyNestedPass(pm, mlir::createSCFBufferizePass(), enablePass); addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass); addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(), enablePass);