mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
enhance(compiler): add --parallelize-loops and --parallelize-dataflow compile flags in addition to --parallelize which enables both.
This commit is contained in:
@@ -145,7 +145,8 @@ public:
|
||||
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
|
||||
: overrideMaxEintPrecision(), overrideMaxMANP(),
|
||||
clientParametersFuncName(), verifyDiagnostics(false),
|
||||
autoParallelize(false), generateClientParameters(false),
|
||||
autoParallelize(false), loopParallelize(false),
|
||||
dataflowParallelize(false), generateClientParameters(false),
|
||||
enablePass([](mlir::Pass *pass) { return true; }),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
@@ -170,6 +171,8 @@ public:
|
||||
void setMaxMANP(size_t v);
|
||||
void setVerifyDiagnostics(bool v);
|
||||
void setAutoParallelize(bool v);
|
||||
void setLoopParallelize(bool v);
|
||||
void setDataflowParallelize(bool v);
|
||||
void setGenerateClientParameters(bool v);
|
||||
void setClientParametersFuncName(const llvm::StringRef &name);
|
||||
void setFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
|
||||
@@ -183,6 +186,8 @@ protected:
|
||||
|
||||
bool verifyDiagnostics;
|
||||
bool autoParallelize;
|
||||
bool loopParallelize;
|
||||
bool dataflowParallelize;
|
||||
bool generateClientParameters;
|
||||
std::function<bool(mlir::Pass *)> enablePass;
|
||||
|
||||
|
||||
@@ -96,6 +96,12 @@ void CompilerEngine::setVerifyDiagnostics(bool v) {
|
||||
|
||||
void CompilerEngine::setAutoParallelize(bool v) { this->autoParallelize = v; }
|
||||
|
||||
void CompilerEngine::setLoopParallelize(bool v) { this->loopParallelize = v; }
|
||||
|
||||
void CompilerEngine::setDataflowParallelize(bool v) {
|
||||
this->dataflowParallelize = v;
|
||||
}
|
||||
|
||||
void CompilerEngine::setGenerateClientParameters(bool v) {
|
||||
this->generateClientParameters = v;
|
||||
}
|
||||
@@ -227,11 +233,11 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return errorDiag("Tiling of FHELinalg operations failed");
|
||||
}
|
||||
|
||||
// Auto parallelization
|
||||
if (this->autoParallelize &&
|
||||
// Dataflow parallelization
|
||||
if ((this->autoParallelize || this->dataflowParallelize) &&
|
||||
mlir::concretelang::pipeline::autopar(mlirContext, module, enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Auto parallelization failed");
|
||||
return StreamStringError("Dataflow parallelization failed");
|
||||
}
|
||||
|
||||
if (target == Target::FHE)
|
||||
@@ -298,7 +304,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
// MLIR canonical dialects -> LLVM Dialect
|
||||
if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
|
||||
mlirContext, module, enablePass,
|
||||
/*parallelizeLoops =*/this->autoParallelize)
|
||||
this->loopParallelize || this->autoParallelize)
|
||||
.failed()) {
|
||||
return errorDiag("Failed to lower to LLVM dialect");
|
||||
}
|
||||
|
||||
@@ -133,6 +133,18 @@ llvm::cl::opt<bool> autoParallelize(
|
||||
llvm::cl::desc("Generate (and execute if JIT) parallel code"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> loopParallelize(
|
||||
"parallelize-loops",
|
||||
llvm::cl::desc(
|
||||
"Generate (and execute if JIT) parallel loops from Linalg operations"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> dataflowParallelize(
|
||||
"parallelize-dataflow",
|
||||
llvm::cl::desc(
|
||||
"Generate (and execute if JIT) the program as a dataflow graph"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<std::string>
|
||||
funcName("funcname",
|
||||
llvm::cl::desc("Name of the function to compile, default 'main'"),
|
||||
@@ -244,7 +256,7 @@ mlir::LogicalResult processInputBuffer(
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
|
||||
llvm::Optional<llvm::ArrayRef<int64_t>> fhelinalgTileSizes,
|
||||
bool autoParallelize,
|
||||
bool autoParallelize, bool loopParallelize, bool dataflowParallelize,
|
||||
llvm::Optional<mlir::concretelang::KeySetCache> keySetCache,
|
||||
llvm::raw_ostream &os,
|
||||
std::shared_ptr<mlir::concretelang::CompilerEngine::Library> outputLib) {
|
||||
@@ -255,6 +267,8 @@ mlir::LogicalResult processInputBuffer(
|
||||
|
||||
ce.setVerifyDiagnostics(verifyDiagnostics);
|
||||
ce.setAutoParallelize(autoParallelize);
|
||||
ce.setLoopParallelize(loopParallelize);
|
||||
ce.setDataflowParallelize(dataflowParallelize);
|
||||
if (cmdline::passes.size() != 0) {
|
||||
ce.setEnablePass([](mlir::Pass *pass) {
|
||||
return std::any_of(
|
||||
@@ -429,8 +443,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
std::move(inputBuffer), fileName, cmdline::action, cmdline::funcName,
|
||||
cmdline::jitArgs, cmdline::assumeMaxEintPrecision,
|
||||
cmdline::assumeMaxMANP, cmdline::verifyDiagnostics,
|
||||
fhelinalgTileSizes, cmdline::autoParallelize, jitKeySetCache, os,
|
||||
outputLib);
|
||||
fhelinalgTileSizes, cmdline::autoParallelize,
|
||||
cmdline::loopParallelize, cmdline::dataflowParallelize,
|
||||
jitKeySetCache, os, outputLib);
|
||||
};
|
||||
auto &os = output->os();
|
||||
auto res = mlir::failure();
|
||||
|
||||
Reference in New Issue
Block a user