enhance(compiler): add --parallelize-loops and --parallelize-dataflow compile flags in addition to --parallelize which enables both.

This commit is contained in:
Antoniu Pop
2022-02-01 15:22:07 +00:00
committed by Antoniu Pop
parent 35e6966f95
commit dddad849c7
3 changed files with 34 additions and 8 deletions

View File

@@ -145,7 +145,8 @@ public:
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
: overrideMaxEintPrecision(), overrideMaxMANP(),
clientParametersFuncName(), verifyDiagnostics(false),
autoParallelize(false), generateClientParameters(false),
autoParallelize(false), loopParallelize(false),
dataflowParallelize(false), generateClientParameters(false),
enablePass([](mlir::Pass *pass) { return true; }),
compilationContext(compilationContext) {}
@@ -170,6 +171,8 @@ public:
void setMaxMANP(size_t v);
void setVerifyDiagnostics(bool v);
void setAutoParallelize(bool v);
void setLoopParallelize(bool v);
void setDataflowParallelize(bool v);
void setGenerateClientParameters(bool v);
void setClientParametersFuncName(const llvm::StringRef &name);
void setFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
@@ -183,6 +186,8 @@ protected:
bool verifyDiagnostics;
bool autoParallelize;
bool loopParallelize;
bool dataflowParallelize;
bool generateClientParameters;
std::function<bool(mlir::Pass *)> enablePass;

View File

@@ -96,6 +96,12 @@ void CompilerEngine::setVerifyDiagnostics(bool v) {
void CompilerEngine::setAutoParallelize(bool v) { this->autoParallelize = v; }
void CompilerEngine::setLoopParallelize(bool v) { this->loopParallelize = v; }
void CompilerEngine::setDataflowParallelize(bool v) {
this->dataflowParallelize = v;
}
void CompilerEngine::setGenerateClientParameters(bool v) {
this->generateClientParameters = v;
}
@@ -227,11 +233,11 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return errorDiag("Tiling of FHELinalg operations failed");
}
// Auto parallelization
if (this->autoParallelize &&
// Dataflow parallelization
if ((this->autoParallelize || this->dataflowParallelize) &&
mlir::concretelang::pipeline::autopar(mlirContext, module, enablePass)
.failed()) {
return StreamStringError("Auto parallelization failed");
return StreamStringError("Dataflow parallelization failed");
}
if (target == Target::FHE)
@@ -298,7 +304,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
// MLIR canonical dialects -> LLVM Dialect
if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
mlirContext, module, enablePass,
/*parallelizeLoops =*/this->autoParallelize)
this->loopParallelize || this->autoParallelize)
.failed()) {
return errorDiag("Failed to lower to LLVM dialect");
}

View File

@@ -133,6 +133,18 @@ llvm::cl::opt<bool> autoParallelize(
llvm::cl::desc("Generate (and execute if JIT) parallel code"),
llvm::cl::init(false));
llvm::cl::opt<bool> loopParallelize(
"parallelize-loops",
llvm::cl::desc(
"Generate (and execute if JIT) parallel loops from Linalg operations"),
llvm::cl::init(false));
llvm::cl::opt<bool> dataflowParallelize(
"parallelize-dataflow",
llvm::cl::desc(
"Generate (and execute if JIT) the program as a dataflow graph"),
llvm::cl::init(false));
llvm::cl::opt<std::string>
funcName("funcname",
llvm::cl::desc("Name of the function to compile, default 'main'"),
@@ -244,7 +256,7 @@ mlir::LogicalResult processInputBuffer(
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
llvm::Optional<llvm::ArrayRef<int64_t>> fhelinalgTileSizes,
bool autoParallelize,
bool autoParallelize, bool loopParallelize, bool dataflowParallelize,
llvm::Optional<mlir::concretelang::KeySetCache> keySetCache,
llvm::raw_ostream &os,
std::shared_ptr<mlir::concretelang::CompilerEngine::Library> outputLib) {
@@ -255,6 +267,8 @@ mlir::LogicalResult processInputBuffer(
ce.setVerifyDiagnostics(verifyDiagnostics);
ce.setAutoParallelize(autoParallelize);
ce.setLoopParallelize(loopParallelize);
ce.setDataflowParallelize(dataflowParallelize);
if (cmdline::passes.size() != 0) {
ce.setEnablePass([](mlir::Pass *pass) {
return std::any_of(
@@ -429,8 +443,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
std::move(inputBuffer), fileName, cmdline::action, cmdline::funcName,
cmdline::jitArgs, cmdline::assumeMaxEintPrecision,
cmdline::assumeMaxMANP, cmdline::verifyDiagnostics,
fhelinalgTileSizes, cmdline::autoParallelize, jitKeySetCache, os,
outputLib);
fhelinalgTileSizes, cmdline::autoParallelize,
cmdline::loopParallelize, cmdline::dataflowParallelize,
jitKeySetCache, os, outputLib);
};
auto &os = output->os();
auto res = mlir::failure();