From 54089186ae0bf4bfabd55e4386ac9b296da3fea1 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 7 Aug 2023 14:29:06 +0100 Subject: [PATCH] refactor(compiler): reorganize passes and add memory usage pass --- .../include/concretelang/Support/Pipeline.h | 24 +++++++-- .../compiler/lib/Support/CompilerEngine.cpp | 33 +++++++++--- .../compiler/lib/Support/Pipeline.cpp | 51 +++++++++++++++---- 3 files changed, 87 insertions(+), 21 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h index 4787d0836..1ff061b3a 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h @@ -81,6 +81,11 @@ mlir::LogicalResult lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); +mlir::LogicalResult +computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass, + CompilationFeedback &feedback); + mlir::LogicalResult lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, @@ -100,9 +105,9 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context, bool unrollLoops); mlir::LogicalResult -lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass, - bool simulation); +addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass, + bool simulation); mlir::LogicalResult lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, @@ -110,8 +115,17 @@ lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass, - bool parallelizeLoops, bool gpu); + std::function enablePass); + +mlir::LogicalResult lowerToStd(mlir::MLIRContext &context, + mlir::ModuleOp &module, + std::function enablePass, + bool parallelizeLoops); + +mlir::LogicalResult lowerToCAPI(mlir::MLIRContext &context, + mlir::ModuleOp &module, + std::function enablePass, + bool gpu); mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext, llvm::Module &module); diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index b35868ed0..0601a42ba 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -521,13 +521,11 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, return std::move(res); } - // Concrete -> Canonical dialects - if (mlir::concretelang::pipeline::lowerConcreteToStd( + // Add runtime context in Concrete + if (mlir::concretelang::pipeline::addRuntimeContext( mlirContext, module, enablePass, options.simulate) .failed()) { - return StreamStringError( - "Lowering from Bufferized Concrete to canonical MLIR " - "dialects failed"); + return StreamStringError("Adding Runtime Context failed"); } // SDFG -> Canonical dialects @@ -538,12 +536,33 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, "Lowering from SDFG to canonical MLIR dialects failed"); } + // bufferize and related passes + if (mlir::concretelang::pipeline::lowerToStd(mlirContext, module, enablePass, + loopParallelize) + .failed()) { + return StreamStringError("Failed to lower to std"); + } + if (target == Target::STD) return std::move(res); + if (res.feedback) { + if (mlir::concretelang::pipeline::computeMemoryUsage( + mlirContext, module, this->enablePass, res.feedback.value()) + .failed()) { + return StreamStringError("Computing memory usage failed"); + } + } + + if (mlir::concretelang::pipeline::lowerToCAPI(mlirContext, module, enablePass, + options.emitGPUOps) + .failed()) { + return StreamStringError("Failed to lower to CAPI"); + } + // MLIR canonical dialects -> LLVM Dialect - if (mlir::concretelang::pipeline::lowerStdToLLVMDialect( - mlirContext, module, enablePass, loopParallelize, options.emitGPUOps) + if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(mlirContext, module, + enablePass) .failed()) { return StreamStringError("Failed to lower to LLVM dialect"); } diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index d776204fd..2844d7ba2 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -32,6 +32,7 @@ #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Error.h" #include +#include #include #include #include @@ -349,6 +350,19 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } +mlir::LogicalResult +computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass, + CompilationFeedback &feedback) { + mlir::PassManager pm(&context); + pipelinePrinting("Computing Memory Usage", pm, context); + + addPotentiallyNestedPass( + pm, std::make_unique(feedback), enablePass); + + return pm.run(module.getOperation()); +} + mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass) { @@ -385,11 +399,11 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context, } mlir::LogicalResult -lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass, - bool simulation) { +addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass, + bool simulation) { mlir::PassManager pm(&context); - pipelinePrinting("ConcreteToStd", pm, context); + pipelinePrinting("Adding Runtime Context", pm, context); if (!simulation) { addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(), enablePass); @@ -408,12 +422,12 @@ lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } -mlir::LogicalResult -lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass, - bool parallelizeLoops, bool gpu) { +mlir::LogicalResult lowerToStd(mlir::MLIRContext &context, + mlir::ModuleOp &module, + std::function enablePass, + bool parallelizeLoops) { mlir::PassManager pm(&context); - pipelinePrinting("StdToLLVM", pm, context); + pipelinePrinting("Lowering to Std", pm, context); // Bufferize mlir::bufferization::OneShotBufferizationOptions bufferizationOptions; @@ -467,11 +481,30 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, addPotentiallyNestedPass( pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass); + return pm.run(module); +} + +mlir::LogicalResult lowerToCAPI(mlir::MLIRContext &context, + mlir::ModuleOp &module, + std::function enablePass, + bool gpu) { + mlir::PassManager pm(&context); + pipelinePrinting("Lowering to CAPI", pm, context); + addPotentiallyNestedPass( pm, mlir::concretelang::createConvertConcreteToCAPIPass(gpu), enablePass); addPotentiallyNestedPass( pm, mlir::concretelang::createConvertTracingToCAPIPass(), enablePass); + return pm.run(module); +} + +mlir::LogicalResult +lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("StdToLLVM", pm, context); + // Convert to MLIR LLVM Dialect addPotentiallyNestedPass( pm, mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass(),