fix(compiler): Activate the loop parallelism in ConcreteToBConcrete when the option is set

This commit is contained in:
Quentin Bourgerie
2022-03-24 17:05:42 +01:00
parent d8aa9ff76b
commit fc51b1d2ab
6 changed files with 29 additions and 13 deletions

View File

@@ -11,7 +11,8 @@
namespace mlir {
namespace concretelang {
/// Create a pass to convert `Concrete` dialect to `BConcrete` dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertConcreteToBConcretePass();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertConcreteToBConcretePass(bool loopParallelize);
} // namespace concretelang
} // namespace mlir

View File

@@ -36,7 +36,6 @@ def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> {
let summary = "Lowers operations from the Concrete dialect to Bufferized Concrete";
let description = [{ Lowers operations from the Concrete dialect to Bufferized Concrete }];
let constructor = "mlir::concretelang::createConvertConcreteToBConcretePass()";
let options = [];
let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"];
}

View File

@@ -44,7 +44,8 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops);
mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,

View File

@@ -24,6 +24,12 @@ namespace {
struct ConcreteToBConcretePass
: public ConcreteToBConcreteBase<ConcreteToBConcretePass> {
void runOnOperation() final;
ConcreteToBConcretePass() = delete;
ConcreteToBConcretePass(bool loopParallelize)
: loopParallelize(loopParallelize){};
private:
bool loopParallelize;
};
} // namespace
@@ -953,12 +959,17 @@ void ConcreteToBConcretePass::runOnOperation() {
// Add patterns to rewrite linalg op to nested loops with views on
// ciphertexts
patterns.insert<LinalgRewritePattern<mlir::scf::ForOp>>(converter,
&getContext());
if (loopParallelize) {
patterns.insert<LinalgRewritePattern<mlir::scf::ParallelOp>>(
converter, &getContext());
} else {
patterns.insert<LinalgRewritePattern<mlir::scf::ForOp>>(converter,
&getContext());
}
target.addLegalOp<mlir::arith::ConstantOp, mlir::scf::ForOp,
mlir::scf::YieldOp, mlir::AffineApplyOp,
mlir::memref::SubViewOp, mlir::memref::LoadOp,
mlir::memref::TensorStoreOp>();
mlir::scf::ParallelOp, mlir::scf::YieldOp,
mlir::AffineApplyOp, mlir::memref::SubViewOp,
mlir::memref::LoadOp, mlir::memref::TensorStoreOp>();
// Add patterns to do the conversion of func
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
@@ -991,8 +1002,8 @@ void ConcreteToBConcretePass::runOnOperation() {
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertConcreteToBConcretePass() {
return std::make_unique<ConcreteToBConcretePass>();
createConvertConcreteToBConcretePass(bool loopParallelize) {
return std::make_unique<ConcreteToBConcretePass>(loopParallelize);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -299,7 +299,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
// Concrete -> BConcrete
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
mlirContext, module, this->enablePass)
mlirContext, module, this->enablePass,
this->loopParallelize || this->autoParallelize)
.failed()) {
return StreamStringError(
"Lowering from Concrete to Bufferized Concrete failed");

View File

@@ -187,11 +187,14 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops) {
mlir::PassManager pm(&context);
pipelinePrinting("ConcreteToBConcrete", pm, context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertConcreteToBConcretePass(),
pm,
mlir::concretelang::createConvertConcreteToBConcretePass(
parallelizeLoops),
enablePass);
return pm.run(module.getOperation());