refactor(compiler): reorganize passes and add memory usage pass

This commit is contained in:
youben11
2023-08-07 14:29:06 +01:00
committed by Ayoub Benaissa
parent d88b2c87ac
commit 54089186ae
3 changed files with 87 additions and 21 deletions

View File

@@ -81,6 +81,11 @@ mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
CompilationFeedback &feedback);
mlir::LogicalResult
lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
@@ -100,9 +105,9 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
bool unrollLoops);
mlir::LogicalResult
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation);
addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation);
mlir::LogicalResult
lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
@@ -110,8 +115,17 @@ lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops, bool gpu);
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult lowerToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops);
mlir::LogicalResult lowerToCAPI(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool gpu);
mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
llvm::Module &module);

View File

@@ -521,13 +521,11 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
return std::move(res);
}
// Concrete -> Canonical dialects
if (mlir::concretelang::pipeline::lowerConcreteToStd(
// Add runtime context in Concrete
if (mlir::concretelang::pipeline::addRuntimeContext(
mlirContext, module, enablePass, options.simulate)
.failed()) {
return StreamStringError(
"Lowering from Bufferized Concrete to canonical MLIR "
"dialects failed");
return StreamStringError("Adding Runtime Context failed");
}
// SDFG -> Canonical dialects
@@ -538,12 +536,33 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
"Lowering from SDFG to canonical MLIR dialects failed");
}
// bufferize and related passes
if (mlir::concretelang::pipeline::lowerToStd(mlirContext, module, enablePass,
loopParallelize)
.failed()) {
return StreamStringError("Failed to lower to std");
}
if (target == Target::STD)
return std::move(res);
if (res.feedback) {
if (mlir::concretelang::pipeline::computeMemoryUsage(
mlirContext, module, this->enablePass, res.feedback.value())
.failed()) {
return StreamStringError("Computing memory usage failed");
}
}
if (mlir::concretelang::pipeline::lowerToCAPI(mlirContext, module, enablePass,
options.emitGPUOps)
.failed()) {
return StreamStringError("Failed to lower to CAPI");
}
// MLIR canonical dialects -> LLVM Dialect
if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
mlirContext, module, enablePass, loopParallelize, options.emitGPUOps)
if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(mlirContext, module,
enablePass)
.failed()) {
return StreamStringError("Failed to lower to LLVM dialect");
}

View File

@@ -32,6 +32,7 @@
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include <concretelang/Conversion/Passes.h>
#include <concretelang/Dialect/Concrete/Analysis/MemoryUsage.h>
#include <concretelang/Dialect/Concrete/Transforms/Passes.h>
#include <concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h>
#include <concretelang/Dialect/FHE/Analysis/MANP.h>
@@ -349,6 +350,19 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
CompilationFeedback &feedback) {
mlir::PassManager pm(&context);
pipelinePrinting("Computing Memory Usage", pm, context);
addPotentiallyNestedPass(
pm, std::make_unique<Concrete::MemoryUsagePass>(feedback), enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
@@ -385,11 +399,11 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
}
mlir::LogicalResult
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation) {
addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation) {
mlir::PassManager pm(&context);
pipelinePrinting("ConcreteToStd", pm, context);
pipelinePrinting("Adding Runtime Context", pm, context);
if (!simulation) {
addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(),
enablePass);
@@ -408,12 +422,12 @@ lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops, bool gpu) {
mlir::LogicalResult lowerToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops) {
mlir::PassManager pm(&context);
pipelinePrinting("StdToLLVM", pm, context);
pipelinePrinting("Lowering to Std", pm, context);
// Bufferize
mlir::bufferization::OneShotBufferizationOptions bufferizationOptions;
@@ -467,11 +481,30 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
addPotentiallyNestedPass(
pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass);
return pm.run(module);
}
mlir::LogicalResult lowerToCAPI(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool gpu) {
mlir::PassManager pm(&context);
pipelinePrinting("Lowering to CAPI", pm, context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertConcreteToCAPIPass(gpu), enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertTracingToCAPIPass(), enablePass);
return pm.run(module);
}
mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("StdToLLVM", pm, context);
// Convert to MLIR LLVM Dialect
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass(),