From dddad849c79a8c9854033a3e99a6fcffadb93542 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Tue, 1 Feb 2022 15:22:07 +0000 Subject: [PATCH] enhance(compiler): add --parallelize-loops and --parallelize-dataflow compile flags in addition to --parallelize which enables both. --- .../concretelang/Support/CompilerEngine.h | 7 ++++++- compiler/lib/Support/CompilerEngine.cpp | 14 +++++++++---- compiler/src/main.cpp | 21 ++++++++++++++++--- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index b6d014244..da7f6f9bb 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -145,7 +145,8 @@ public: CompilerEngine(std::shared_ptr compilationContext) : overrideMaxEintPrecision(), overrideMaxMANP(), clientParametersFuncName(), verifyDiagnostics(false), - autoParallelize(false), generateClientParameters(false), + autoParallelize(false), loopParallelize(false), + dataflowParallelize(false), generateClientParameters(false), enablePass([](mlir::Pass *pass) { return true; }), compilationContext(compilationContext) {} @@ -170,6 +171,8 @@ public: void setMaxMANP(size_t v); void setVerifyDiagnostics(bool v); void setAutoParallelize(bool v); + void setLoopParallelize(bool v); + void setDataflowParallelize(bool v); void setGenerateClientParameters(bool v); void setClientParametersFuncName(const llvm::StringRef &name); void setFHELinalgTileSizes(llvm::ArrayRef sizes); @@ -183,6 +186,8 @@ protected: bool verifyDiagnostics; bool autoParallelize; + bool loopParallelize; + bool dataflowParallelize; bool generateClientParameters; std::function enablePass; diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 26806bf38..42b8eea55 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -96,6 +96,12 @@ void CompilerEngine::setVerifyDiagnostics(bool v) { void CompilerEngine::setAutoParallelize(bool v) { this->autoParallelize = v; } +void CompilerEngine::setLoopParallelize(bool v) { this->loopParallelize = v; } + +void CompilerEngine::setDataflowParallelize(bool v) { + this->dataflowParallelize = v; +} + void CompilerEngine::setGenerateClientParameters(bool v) { this->generateClientParameters = v; } @@ -227,11 +233,11 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return errorDiag("Tiling of FHELinalg operations failed"); } - // Auto parallelization - if (this->autoParallelize && + // Dataflow parallelization + if ((this->autoParallelize || this->dataflowParallelize) && mlir::concretelang::pipeline::autopar(mlirContext, module, enablePass) .failed()) { - return StreamStringError("Auto parallelization failed"); + return StreamStringError("Dataflow parallelization failed"); } if (target == Target::FHE) @@ -298,7 +304,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { // MLIR canonical dialects -> LLVM Dialect if (mlir::concretelang::pipeline::lowerStdToLLVMDialect( mlirContext, module, enablePass, - /*parallelizeLoops =*/this->autoParallelize) + this->loopParallelize || this->autoParallelize) .failed()) { return errorDiag("Failed to lower to LLVM dialect"); } diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index ab435f519..74b1f83e0 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -133,6 +133,18 @@ llvm::cl::opt autoParallelize( llvm::cl::desc("Generate (and execute if JIT) parallel code"), llvm::cl::init(false)); +llvm::cl::opt loopParallelize( + "parallelize-loops", + llvm::cl::desc( + "Generate (and execute if JIT) parallel loops from Linalg operations"), + llvm::cl::init(false)); + +llvm::cl::opt dataflowParallelize( + "parallelize-dataflow", + llvm::cl::desc( + "Generate (and execute if JIT) the program as a dataflow graph"), + llvm::cl::init(false)); + llvm::cl::opt funcName("funcname", llvm::cl::desc("Name of the function to compile, default 'main'"), @@ -244,7 +256,7 @@ mlir::LogicalResult processInputBuffer( llvm::Optional overrideMaxEintPrecision, llvm::Optional overrideMaxMANP, bool verifyDiagnostics, llvm::Optional> fhelinalgTileSizes, - bool autoParallelize, + bool autoParallelize, bool loopParallelize, bool dataflowParallelize, llvm::Optional keySetCache, llvm::raw_ostream &os, std::shared_ptr outputLib) { @@ -255,6 +267,8 @@ mlir::LogicalResult processInputBuffer( ce.setVerifyDiagnostics(verifyDiagnostics); ce.setAutoParallelize(autoParallelize); + ce.setLoopParallelize(loopParallelize); + ce.setDataflowParallelize(dataflowParallelize); if (cmdline::passes.size() != 0) { ce.setEnablePass([](mlir::Pass *pass) { return std::any_of( @@ -429,8 +443,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { std::move(inputBuffer), fileName, cmdline::action, cmdline::funcName, cmdline::jitArgs, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, - fhelinalgTileSizes, cmdline::autoParallelize, jitKeySetCache, os, - outputLib); + fhelinalgTileSizes, cmdline::autoParallelize, + cmdline::loopParallelize, cmdline::dataflowParallelize, + jitKeySetCache, os, outputLib); }; auto &os = output->os(); auto res = mlir::failure();