Mirror of https://github.com/zama-ai/concrete.git, synced 2026-02-17 16:11:26 -05:00
refactor(compiler): reorganize passes and add memory usage pass
@@ -81,6 +81,11 @@ mlir::LogicalResult
 lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
                     std::function<bool(mlir::Pass *)> enablePass);
 
+mlir::LogicalResult
+computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
+                   std::function<bool(mlir::Pass *)> enablePass,
+                   CompilationFeedback &feedback);
+
 mlir::LogicalResult
 lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
                            std::function<bool(mlir::Pass *)> enablePass,
@@ -100,9 +105,9 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
                                    bool unrollLoops);
 
 mlir::LogicalResult
-lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
-                   std::function<bool(mlir::Pass *)> enablePass,
-                   bool simulation);
+addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module,
+                  std::function<bool(mlir::Pass *)> enablePass,
+                  bool simulation);
 
 mlir::LogicalResult
 lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
@@ -110,8 +115,17 @@ lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
 
 mlir::LogicalResult
 lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
-                      std::function<bool(mlir::Pass *)> enablePass,
-                      bool parallelizeLoops, bool gpu);
+                      std::function<bool(mlir::Pass *)> enablePass);
+
+mlir::LogicalResult lowerToStd(mlir::MLIRContext &context,
+                               mlir::ModuleOp &module,
+                               std::function<bool(mlir::Pass *)> enablePass,
+                               bool parallelizeLoops);
+
+mlir::LogicalResult lowerToCAPI(mlir::MLIRContext &context,
+                                mlir::ModuleOp &module,
+                                std::function<bool(mlir::Pass *)> enablePass,
+                                bool gpu);
 
 mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
                                        llvm::Module &module);
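
For orientation, a minimal sketch of how a caller might chain the entry points declared above. It is not part of the commit: the include path, the wrapper function runBackendSketch, the namespace qualification of CompilationFeedback, and the always-true enablePass predicate are illustrative assumptions; only the pipeline functions and their parameters come from the declarations shown.

// Illustrative sketch only; not part of this commit.
#include "concretelang/Support/Pipeline.h" // assumed location of the declarations

static mlir::LogicalResult
runBackendSketch(mlir::MLIRContext &context, mlir::ModuleOp &module,
                 mlir::concretelang::CompilationFeedback &feedback) {
  // Run every pass; a real caller can filter passes by name here.
  auto enablePass = [](mlir::Pass *) { return true; };

  // Bufferize and lower to the std-level dialects.
  if (mlir::concretelang::pipeline::lowerToStd(context, module, enablePass,
                                               /*parallelizeLoops=*/false)
          .failed())
    return mlir::failure();

  // New in this commit: record memory usage in the compilation feedback.
  if (mlir::concretelang::pipeline::computeMemoryUsage(context, module,
                                                       enablePass, feedback)
          .failed())
    return mlir::failure();

  if (mlir::concretelang::pipeline::lowerToCAPI(context, module, enablePass,
                                                /*gpu=*/false)
          .failed())
    return mlir::failure();

  return mlir::concretelang::pipeline::lowerStdToLLVMDialect(context, module,
                                                             enablePass);
}
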
@@ -521,13 +521,11 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
     return std::move(res);
   }
 
-  // Concrete -> Canonical dialects
-  if (mlir::concretelang::pipeline::lowerConcreteToStd(
+  // Add runtime context in Concrete
+  if (mlir::concretelang::pipeline::addRuntimeContext(
           mlirContext, module, enablePass, options.simulate)
           .failed()) {
-    return StreamStringError(
-        "Lowering from Bufferized Concrete to canonical MLIR "
-        "dialects failed");
+    return StreamStringError("Adding Runtime Context failed");
   }
 
   // SDFG -> Canonical dialects
@@ -538,12 +536,33 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
         "Lowering from SDFG to canonical MLIR dialects failed");
   }
 
+  // bufferize and related passes
+  if (mlir::concretelang::pipeline::lowerToStd(mlirContext, module, enablePass,
+                                               loopParallelize)
+          .failed()) {
+    return StreamStringError("Failed to lower to std");
+  }
+
   if (target == Target::STD)
     return std::move(res);
 
+  if (res.feedback) {
+    if (mlir::concretelang::pipeline::computeMemoryUsage(
+            mlirContext, module, this->enablePass, res.feedback.value())
+            .failed()) {
+      return StreamStringError("Computing memory usage failed");
+    }
+  }
+
+  if (mlir::concretelang::pipeline::lowerToCAPI(mlirContext, module, enablePass,
+                                                options.emitGPUOps)
+          .failed()) {
+    return StreamStringError("Failed to lower to CAPI");
+  }
+
   // MLIR canonical dialects -> LLVM Dialect
-  if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
-          mlirContext, module, enablePass, loopParallelize, options.emitGPUOps)
+  if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(mlirContext, module,
+                                                          enablePass)
           .failed()) {
     return StreamStringError("Failed to lower to LLVM dialect");
   }
@@ -32,6 +32,7 @@
 #include "concretelang/Support/CompilerEngine.h"
 #include "concretelang/Support/Error.h"
 #include <concretelang/Conversion/Passes.h>
+#include <concretelang/Dialect/Concrete/Analysis/MemoryUsage.h>
 #include <concretelang/Dialect/Concrete/Transforms/Passes.h>
 #include <concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h>
 #include <concretelang/Dialect/FHE/Analysis/MANP.h>
@@ -349,6 +350,19 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
   return pm.run(module.getOperation());
 }
 
+mlir::LogicalResult
+computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
+                   std::function<bool(mlir::Pass *)> enablePass,
+                   CompilationFeedback &feedback) {
+  mlir::PassManager pm(&context);
+  pipelinePrinting("Computing Memory Usage", pm, context);
+
+  addPotentiallyNestedPass(
+      pm, std::make_unique<Concrete::MemoryUsagePass>(feedback), enablePass);
+
+  return pm.run(module.getOperation());
+}
+
 mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context,
                                  mlir::ModuleOp &module,
                                  std::function<bool(mlir::Pass *)> enablePass) {
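
The Concrete::MemoryUsagePass scheduled above is declared in Dialect/Concrete/Analysis/MemoryUsage.h and its body is not part of this diff. As a rough sketch of the general shape a feedback-collecting module pass can take, one could write something like the following; everything in it is hypothetical, including the FeedbackSketch struct, its totalBufferSize field, and the byte-counting heuristic.

// Hypothetical sketch only; not the actual Concrete::MemoryUsagePass.
#include <cstdint>

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"

// Stand-in for the real CompilationFeedback, for illustration only.
struct FeedbackSketch {
  uint64_t totalBufferSize = 0; // invented field
};

// Walks statically shaped memref.alloc ops and accumulates their size in
// bytes into the feedback struct.
struct MemoryUsageSketchPass
    : public mlir::PassWrapper<MemoryUsageSketchPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  FeedbackSketch &feedback;
  explicit MemoryUsageSketchPass(FeedbackSketch &feedback)
      : feedback(feedback) {}

  void runOnOperation() override {
    getOperation().walk([&](mlir::memref::AllocOp alloc) {
      mlir::MemRefType type = alloc.getType();
      if (type.hasStaticShape())
        feedback.totalBufferSize +=
            type.getNumElements() * type.getElementTypeBitWidth() / 8;
    });
  }
};
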
@@ -385,11 +399,11 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
 }
 
 mlir::LogicalResult
-lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
-                   std::function<bool(mlir::Pass *)> enablePass,
-                   bool simulation) {
+addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module,
+                  std::function<bool(mlir::Pass *)> enablePass,
+                  bool simulation) {
   mlir::PassManager pm(&context);
-  pipelinePrinting("ConcreteToStd", pm, context);
+  pipelinePrinting("Adding Runtime Context", pm, context);
   if (!simulation) {
     addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(),
                              enablePass);
@@ -408,12 +422,12 @@ lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
   return pm.run(module.getOperation());
 }
 
-mlir::LogicalResult
-lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
-                      std::function<bool(mlir::Pass *)> enablePass,
-                      bool parallelizeLoops, bool gpu) {
+mlir::LogicalResult lowerToStd(mlir::MLIRContext &context,
+                               mlir::ModuleOp &module,
+                               std::function<bool(mlir::Pass *)> enablePass,
+                               bool parallelizeLoops) {
   mlir::PassManager pm(&context);
-  pipelinePrinting("StdToLLVM", pm, context);
+  pipelinePrinting("Lowering to Std", pm, context);
 
   // Bufferize
   mlir::bufferization::OneShotBufferizationOptions bufferizationOptions;
@@ -467,11 +481,30 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
   addPotentiallyNestedPass(
       pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass);
 
+  return pm.run(module);
+}
+
+mlir::LogicalResult lowerToCAPI(mlir::MLIRContext &context,
+                                mlir::ModuleOp &module,
+                                std::function<bool(mlir::Pass *)> enablePass,
+                                bool gpu) {
+  mlir::PassManager pm(&context);
+  pipelinePrinting("Lowering to CAPI", pm, context);
+
   addPotentiallyNestedPass(
       pm, mlir::concretelang::createConvertConcreteToCAPIPass(gpu), enablePass);
   addPotentiallyNestedPass(
       pm, mlir::concretelang::createConvertTracingToCAPIPass(), enablePass);
 
+  return pm.run(module);
+}
+
+mlir::LogicalResult
+lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
+                      std::function<bool(mlir::Pass *)> enablePass) {
+  mlir::PassManager pm(&context);
+  pipelinePrinting("StdToLLVM", pm, context);
+
   // Convert to MLIR LLVM Dialect
   addPotentiallyNestedPass(
       pm, mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass(),
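
The helpers pipelinePrinting and addPotentiallyNestedPass used throughout these hunks are pre-existing utilities in Pipeline.cpp whose definitions are outside this diff. A plausible sketch of what a helper like addPotentiallyNestedPass could do, assuming it simply gates a pass behind the enablePass predicate and nests function-anchored passes under func.func; the actual implementation may differ.

// Sketch only; not the helper from Pipeline.cpp.
#include <functional>
#include <memory>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

static void addPotentiallyNestedPassSketch(
    mlir::PassManager &pm, std::unique_ptr<mlir::Pass> pass,
    std::function<bool(mlir::Pass *)> enablePass) {
  // Skip the pass entirely when the caller's predicate rejects it.
  if (!enablePass(pass.get()))
    return;
  // Function-anchored passes are nested under func.func; everything else is
  // scheduled at module level.
  if (pass->getOpName() == mlir::func::FuncOp::getOperationName())
    pm.nest<mlir::func::FuncOp>().addPass(std::move(pass));
  else
    pm.addPass(std::move(pass));
}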