mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(compiler): Activate the loop parallelism in ConcreteToBConcrete when the option is set
This commit is contained in:
@@ -11,7 +11,8 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Concrete` dialect to `BConcrete` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertConcreteToBConcretePass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the Concrete dialect to Bufferized Concrete";
|
||||
let description = [{ Lowers operations from the Concrete dialect to Bufferized Concrete }];
|
||||
let constructor = "mlir::concretelang::createConvertConcreteToBConcretePass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"];
|
||||
}
|
||||
|
||||
|
||||
@@ -44,7 +44,8 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
@@ -24,6 +24,12 @@ namespace {
|
||||
struct ConcreteToBConcretePass
|
||||
: public ConcreteToBConcreteBase<ConcreteToBConcretePass> {
|
||||
void runOnOperation() final;
|
||||
ConcreteToBConcretePass() = delete;
|
||||
ConcreteToBConcretePass(bool loopParallelize)
|
||||
: loopParallelize(loopParallelize){};
|
||||
|
||||
private:
|
||||
bool loopParallelize;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@@ -953,12 +959,17 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
|
||||
// Add patterns to rewrite linalg op to nested loops with views on
|
||||
// ciphertexts
|
||||
patterns.insert<LinalgRewritePattern<mlir::scf::ForOp>>(converter,
|
||||
&getContext());
|
||||
if (loopParallelize) {
|
||||
patterns.insert<LinalgRewritePattern<mlir::scf::ParallelOp>>(
|
||||
converter, &getContext());
|
||||
} else {
|
||||
patterns.insert<LinalgRewritePattern<mlir::scf::ForOp>>(converter,
|
||||
&getContext());
|
||||
}
|
||||
target.addLegalOp<mlir::arith::ConstantOp, mlir::scf::ForOp,
|
||||
mlir::scf::YieldOp, mlir::AffineApplyOp,
|
||||
mlir::memref::SubViewOp, mlir::memref::LoadOp,
|
||||
mlir::memref::TensorStoreOp>();
|
||||
mlir::scf::ParallelOp, mlir::scf::YieldOp,
|
||||
mlir::AffineApplyOp, mlir::memref::SubViewOp,
|
||||
mlir::memref::LoadOp, mlir::memref::TensorStoreOp>();
|
||||
|
||||
// Add patterns to do the conversion of func
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
@@ -991,8 +1002,8 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteToBConcretePass() {
|
||||
return std::make_unique<ConcreteToBConcretePass>();
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize) {
|
||||
return std::make_unique<ConcreteToBConcretePass>(loopParallelize);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -299,7 +299,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
|
||||
// Concrete -> BConcrete
|
||||
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
|
||||
mlirContext, module, this->enablePass)
|
||||
mlirContext, module, this->enablePass,
|
||||
this->loopParallelize || this->autoParallelize)
|
||||
.failed()) {
|
||||
return StreamStringError(
|
||||
"Lowering from Concrete to Bufferized Concrete failed");
|
||||
|
||||
@@ -187,11 +187,14 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("ConcreteToBConcrete", pm, context);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertConcreteToBConcretePass(),
|
||||
pm,
|
||||
mlir::concretelang::createConvertConcreteToBConcretePass(
|
||||
parallelizeLoops),
|
||||
enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
|
||||
Reference in New Issue
Block a user