diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index b88687168..b4034279e 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -53,6 +53,7 @@ struct CompilationOptions { bool autoParallelize; bool loopParallelize; + bool batchConcreteOps; bool dataflowParallelize; bool asyncOffload; bool optimizeConcrete; @@ -66,7 +67,7 @@ struct CompilationOptions { CompilationOptions() : v0FHEConstraints(llvm::None), verifyDiagnostics(false), - autoParallelize(false), loopParallelize(false), + autoParallelize(false), loopParallelize(false), batchConcreteOps(false), dataflowParallelize(false), asyncOffload(false), optimizeConcrete(true), emitGPUOps(false), clientParametersFuncName(llvm::None), optimizerConfig(optimizer::DEFAULT_CONFIG){}; @@ -175,6 +176,10 @@ public: /// operations CONCRETE, + /// Read sources and lower all FHE and TFHE operations to Concrete + /// operations with all linalg ops replaced by loops + CONCRETEWITHLOOPS, + /// Read sources and lower all FHE, TFHE and Concrete operations to /// BConcrete operations BCONCRETE, diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index a5ef12ffb..4a2cce61c 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -44,6 +44,11 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, llvm::Optional &fheContext, std::function enablePass); +mlir::LogicalResult +lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass, + bool parallelizeLoops, bool batchOperations); + mlir::LogicalResult lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 10386fde7..d8650f7ed 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -366,6 +366,19 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { } } + // Concrete with linalg ops -> Concrete with loop ops + if (mlir::concretelang::pipeline::lowerConcreteLinalgToLoops( + mlirContext, module, this->enablePass, loopParallelize, + options.batchConcreteOps) + .failed()) { + return StreamStringError( + "Lowering from Concrete with linalg ops to Concrete with loops failed"); + } + + if (target == Target::CONCRETEWITHLOOPS) { + return std::move(res); + } + // Concrete -> BConcrete if (mlir::concretelang::pipeline::lowerConcreteToBConcrete( mlirContext, module, this->enablePass, loopParallelize) diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index c2b116712..98925c1a5 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -239,6 +239,27 @@ optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } +mlir::LogicalResult +lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass, + bool parallelizeLoops, bool batchOperations) { + mlir::PassManager pm(&context); + pipelinePrinting("ConcreteLinalgToLoops", pm, context); + + addPotentiallyNestedPass( + pm, + mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass( + parallelizeLoops), + enablePass); + + if (batchOperations) { + addPotentiallyNestedPass(pm, mlir::concretelang::createBatchingPass(), + enablePass); + } + + return pm.run(module.getOperation()); +} + mlir::LogicalResult lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, @@ -246,19 +267,9 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::PassManager pm(&context); pipelinePrinting("ConcreteToBConcrete", pm, context); - std::unique_ptr conversionPass = - mlir::concretelang::createConvertConcreteToBConcretePass(); - - bool passEnabled = enablePass(conversionPass.get()); - addPotentiallyNestedPass( - pm, - mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass( - parallelizeLoops), - [&](mlir::Pass *) { return passEnabled; }); - - addPotentiallyNestedPass(pm, std::move(conversionPass), - [&](mlir::Pass *) { return passEnabled; }); + pm, mlir::concretelang::createConvertConcreteToBConcretePass(), + enablePass); return pm.run(module.getOperation()); } diff --git a/compiler/src/CMakeLists.txt b/compiler/src/CMakeLists.txt index 2eb39769f..171a82499 100644 --- a/compiler/src/CMakeLists.txt +++ b/compiler/src/CMakeLists.txt @@ -18,6 +18,7 @@ target_link_libraries(concretecompiler TFHEDialect FHEDialect ConcretelangSupport + ConcretelangTransforms MLIRIR MLIRLLVMIRTransforms diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 6a7860f0c..8c3b61241 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -48,6 +48,7 @@ enum Action { DUMP_FHE, DUMP_TFHE, DUMP_CONCRETE, + DUMP_CONCRETEWITHLOOPS, DUMP_BCONCRETE, DUMP_STD, DUMP_LLVM_DIALECT, @@ -121,6 +122,9 @@ static llvm::cl::opt action( "Lower to TFHE and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete", "Lower to Concrete and dump result")), + llvm::cl::values(clEnumValN( + Action::DUMP_CONCRETEWITHLOOPS, "dump-concrete-with-loops", + "Lower to Concrete, replace linalg ops with loops and dump result")), llvm::cl::values( clEnumValN(Action::DUMP_BCONCRETE, "dump-bconcrete", "Lower to Bufferized Concrete and dump result")), @@ -162,6 +166,13 @@ llvm::cl::opt loopParallelize( "Generate (and execute if JIT) parallel loops from Linalg operations"), llvm::cl::init(false)); +llvm::cl::opt batchConcreteOps( + "batch-concrete-ops", + llvm::cl::desc( + "Hoist scalar Concrete operations with corresponding batched " + "operations out of loop nests as batched operations"), + llvm::cl::init(false)); + llvm::cl::opt dataflowParallelize( "parallelize-dataflow", llvm::cl::desc( @@ -288,6 +299,7 @@ cmdlineCompilationOptions() { options.autoParallelize = cmdline::autoParallelize; options.loopParallelize = cmdline::loopParallelize; options.dataflowParallelize = cmdline::dataflowParallelize; + options.batchConcreteOps = cmdline::batchConcreteOps; options.optimizeConcrete = cmdline::optimizeConcrete; options.emitGPUOps = cmdline::emitGPUOps; @@ -457,6 +469,9 @@ mlir::LogicalResult processInputBuffer( case Action::DUMP_CONCRETE: target = mlir::concretelang::CompilerEngine::Target::CONCRETE; break; + case Action::DUMP_CONCRETEWITHLOOPS: + target = mlir::concretelang::CompilerEngine::Target::CONCRETEWITHLOOPS; + break; case Action::DUMP_BCONCRETE: target = mlir::concretelang::CompilerEngine::Target::BCONCRETE; break;