mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(compiler): Add option --batch-concrete-ops and action dump-concrete-with-loops
The new option `--batch-concrete-ops` invokes the batching pass after lowering to the Concrete dialect and after lowering linalg operations with operations from the Concrete dialect to loops. The new action `dump-concrete-with-loops` dumps the IR right before batching.
This commit is contained in:
@@ -53,6 +53,7 @@ struct CompilationOptions {
|
||||
|
||||
bool autoParallelize;
|
||||
bool loopParallelize;
|
||||
bool batchConcreteOps;
|
||||
bool dataflowParallelize;
|
||||
bool asyncOffload;
|
||||
bool optimizeConcrete;
|
||||
@@ -66,7 +67,7 @@ struct CompilationOptions {
|
||||
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false),
|
||||
autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
|
||||
dataflowParallelize(false), asyncOffload(false), optimizeConcrete(true),
|
||||
emitGPUOps(false), clientParametersFuncName(llvm::None),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG){};
|
||||
@@ -175,6 +176,10 @@ public:
|
||||
/// operations
|
||||
CONCRETE,
|
||||
|
||||
/// Read sources and lower all FHE and TFHE operations to Concrete
|
||||
/// operations with all linalg ops replaced by loops
|
||||
CONCRETEWITHLOOPS,
|
||||
|
||||
/// Read sources and lower all FHE, TFHE and Concrete operations to
|
||||
/// BConcrete operations
|
||||
BCONCRETE,
|
||||
|
||||
@@ -44,6 +44,11 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops, bool batchOperations);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
|
||||
@@ -366,6 +366,19 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
}
|
||||
}
|
||||
|
||||
// Concrete with linalg ops -> Concrete with loop ops
|
||||
if (mlir::concretelang::pipeline::lowerConcreteLinalgToLoops(
|
||||
mlirContext, module, this->enablePass, loopParallelize,
|
||||
options.batchConcreteOps)
|
||||
.failed()) {
|
||||
return StreamStringError(
|
||||
"Lowering from Concrete with linalg ops to Concrete with loops failed");
|
||||
}
|
||||
|
||||
if (target == Target::CONCRETEWITHLOOPS) {
|
||||
return std::move(res);
|
||||
}
|
||||
|
||||
// Concrete -> BConcrete
|
||||
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
|
||||
mlirContext, module, this->enablePass, loopParallelize)
|
||||
|
||||
@@ -239,6 +239,27 @@ optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops, bool batchOperations) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("ConcreteLinalgToLoops", pm, context);
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass(
|
||||
parallelizeLoops),
|
||||
enablePass);
|
||||
|
||||
if (batchOperations) {
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createBatchingPass(),
|
||||
enablePass);
|
||||
}
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
@@ -246,19 +267,9 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("ConcreteToBConcrete", pm, context);
|
||||
|
||||
std::unique_ptr<Pass> conversionPass =
|
||||
mlir::concretelang::createConvertConcreteToBConcretePass();
|
||||
|
||||
bool passEnabled = enablePass(conversionPass.get());
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass(
|
||||
parallelizeLoops),
|
||||
[&](mlir::Pass *) { return passEnabled; });
|
||||
|
||||
addPotentiallyNestedPass(pm, std::move(conversionPass),
|
||||
[&](mlir::Pass *) { return passEnabled; });
|
||||
pm, mlir::concretelang::createConvertConcreteToBConcretePass(),
|
||||
enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ target_link_libraries(concretecompiler
|
||||
TFHEDialect
|
||||
FHEDialect
|
||||
ConcretelangSupport
|
||||
ConcretelangTransforms
|
||||
|
||||
MLIRIR
|
||||
MLIRLLVMIRTransforms
|
||||
|
||||
@@ -48,6 +48,7 @@ enum Action {
|
||||
DUMP_FHE,
|
||||
DUMP_TFHE,
|
||||
DUMP_CONCRETE,
|
||||
DUMP_CONCRETEWITHLOOPS,
|
||||
DUMP_BCONCRETE,
|
||||
DUMP_STD,
|
||||
DUMP_LLVM_DIALECT,
|
||||
@@ -121,6 +122,9 @@ static llvm::cl::opt<enum Action> action(
|
||||
"Lower to TFHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete",
|
||||
"Lower to Concrete and dump result")),
|
||||
llvm::cl::values(clEnumValN(
|
||||
Action::DUMP_CONCRETEWITHLOOPS, "dump-concrete-with-loops",
|
||||
"Lower to Concrete, replace linalg ops with loops and dump result")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(Action::DUMP_BCONCRETE, "dump-bconcrete",
|
||||
"Lower to Bufferized Concrete and dump result")),
|
||||
@@ -162,6 +166,13 @@ llvm::cl::opt<bool> loopParallelize(
|
||||
"Generate (and execute if JIT) parallel loops from Linalg operations"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> batchConcreteOps(
|
||||
"batch-concrete-ops",
|
||||
llvm::cl::desc(
|
||||
"Hoist scalar Concrete operations with corresponding batched "
|
||||
"operations out of loop nests as batched operations"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> dataflowParallelize(
|
||||
"parallelize-dataflow",
|
||||
llvm::cl::desc(
|
||||
@@ -288,6 +299,7 @@ cmdlineCompilationOptions() {
|
||||
options.autoParallelize = cmdline::autoParallelize;
|
||||
options.loopParallelize = cmdline::loopParallelize;
|
||||
options.dataflowParallelize = cmdline::dataflowParallelize;
|
||||
options.batchConcreteOps = cmdline::batchConcreteOps;
|
||||
options.optimizeConcrete = cmdline::optimizeConcrete;
|
||||
options.emitGPUOps = cmdline::emitGPUOps;
|
||||
|
||||
@@ -457,6 +469,9 @@ mlir::LogicalResult processInputBuffer(
|
||||
case Action::DUMP_CONCRETE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::CONCRETE;
|
||||
break;
|
||||
case Action::DUMP_CONCRETEWITHLOOPS:
|
||||
target = mlir::concretelang::CompilerEngine::Target::CONCRETEWITHLOOPS;
|
||||
break;
|
||||
case Action::DUMP_BCONCRETE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::BCONCRETE;
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user