feat(compiler): Add option --batch-concrete-ops and action dump-concrete-with-loops

The new option `--batch-concrete-ops` invokes the batching pass after the IR
has been lowered to the Concrete dialect and the linalg operations containing
Concrete-dialect operations have been lowered to loops.

The new action `dump-concrete-with-loops` dumps the IR right before batching.
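As a usage sketch (not part of this commit): assuming the action option is
exposed as `--action=`, like the existing dump actions, the two additions
combine as `concretecompiler --action=dump-concrete-with-loops
--batch-concrete-ops input.mlir`, where `input.mlir` is a placeholder source
file. From the C++ API, the equivalent is setting the new CompilationOptions
field introduced in the first hunk below; the include path and namespace in
this sketch are assumptions, not taken from the diff:

    // Sketch only: enable batching of Concrete ops programmatically.
    // Include path and namespace are assumed, not shown in this commit.
    #include "concretelang/Support/CompilerEngine.h"

    void configureBatching() {
      mlir::concretelang::CompilationOptions options; // default constructor shown in the diff
      options.batchConcreteOps = true;  // same effect as --batch-concrete-ops
      options.loopParallelize = false;  // batching is independent of loop parallelization
      // ... hand `options` to the compiler engine as usual ...
    }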
commit c9bb6541e9
parent 75b70054b2
Author: Andi Drebes
Date:   2022-11-10 16:55:16 +01:00

6 changed files with 63 additions and 13 deletions


@@ -53,6 +53,7 @@ struct CompilationOptions {
   bool autoParallelize;
   bool loopParallelize;
+  bool batchConcreteOps;
   bool dataflowParallelize;
   bool asyncOffload;
   bool optimizeConcrete;
@@ -66,7 +67,7 @@ struct CompilationOptions {
   CompilationOptions()
       : v0FHEConstraints(llvm::None), verifyDiagnostics(false),
-        autoParallelize(false), loopParallelize(false),
+        autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
         dataflowParallelize(false), asyncOffload(false), optimizeConcrete(true),
         emitGPUOps(false), clientParametersFuncName(llvm::None),
         optimizerConfig(optimizer::DEFAULT_CONFIG){};
@@ -175,6 +176,10 @@ public:
   /// operations
   CONCRETE,
+  /// Read sources and lower all FHE and TFHE operations to Concrete
+  /// operations with all linalg ops replaced by loops
+  CONCRETEWITHLOOPS,
   /// Read sources and lower all FHE, TFHE and Concrete operations to
   /// BConcrete operations
   BCONCRETE,


@@ -44,6 +44,11 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
                     llvm::Optional<V0FHEContext> &fheContext,
                     std::function<bool(mlir::Pass *)> enablePass);
+
+mlir::LogicalResult
+lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
+                           std::function<bool(mlir::Pass *)> enablePass,
+                           bool parallelizeLoops, bool batchOperations);
 
 mlir::LogicalResult
 lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
                          std::function<bool(mlir::Pass *)> enablePass,
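A minimal sketch of driving the new pipeline entry point directly, mirroring
the call that CompilerEngine::compile makes in the next hunk; the include is
left as an assumption, and the predicate simply enables every pass:

    // Sketch only: lower linalg ops on Concrete ops to loops, then batch them.
    // #include the pipeline header declaring lowerConcreteLinalgToLoops
    // (exact path not shown in this diff).
    static mlir::LogicalResult
    lowerWithBatching(mlir::MLIRContext &context, mlir::ModuleOp &module) {
      auto enableAll = [](mlir::Pass *) { return true; }; // run every pass
      return mlir::concretelang::pipeline::lowerConcreteLinalgToLoops(
          context, module, enableAll, /*parallelizeLoops=*/false,
          /*batchOperations=*/true);
    }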


@@ -366,6 +366,19 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
     }
   }
 
+  // Concrete with linalg ops -> Concrete with loop ops
+  if (mlir::concretelang::pipeline::lowerConcreteLinalgToLoops(
+          mlirContext, module, this->enablePass, loopParallelize,
+          options.batchConcreteOps)
+          .failed()) {
+    return StreamStringError(
+        "Lowering from Concrete with linalg ops to Concrete with loops failed");
+  }
+
+  if (target == Target::CONCRETEWITHLOOPS) {
+    return std::move(res);
+  }
+
   // Concrete -> BConcrete
   if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
           mlirContext, module, this->enablePass, loopParallelize)

@@ -239,6 +239,27 @@ optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
   return pm.run(module.getOperation());
 }
 
+mlir::LogicalResult
+lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
+                           std::function<bool(mlir::Pass *)> enablePass,
+                           bool parallelizeLoops, bool batchOperations) {
+  mlir::PassManager pm(&context);
+  pipelinePrinting("ConcreteLinalgToLoops", pm, context);
+  addPotentiallyNestedPass(
+      pm,
+      mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass(
+          parallelizeLoops),
+      enablePass);
+  if (batchOperations) {
+    addPotentiallyNestedPass(pm, mlir::concretelang::createBatchingPass(),
+                             enablePass);
+  }
+  return pm.run(module.getOperation());
+}
+
 mlir::LogicalResult
 lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
                          std::function<bool(mlir::Pass *)> enablePass,
@@ -246,19 +267,9 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
   mlir::PassManager pm(&context);
   pipelinePrinting("ConcreteToBConcrete", pm, context);
-  std::unique_ptr<Pass> conversionPass =
-      mlir::concretelang::createConvertConcreteToBConcretePass();
-  bool passEnabled = enablePass(conversionPass.get());
   addPotentiallyNestedPass(
-      pm,
-      mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass(
-          parallelizeLoops),
-      [&](mlir::Pass *) { return passEnabled; });
-  addPotentiallyNestedPass(pm, std::move(conversionPass),
-                           [&](mlir::Pass *) { return passEnabled; });
+      pm, mlir::concretelang::createConvertConcreteToBConcretePass(),
+      enablePass);
   return pm.run(module.getOperation());
 }


@@ -18,6 +18,7 @@ target_link_libraries(concretecompiler
   TFHEDialect
   FHEDialect
   ConcretelangSupport
+  ConcretelangTransforms
   MLIRIR
   MLIRLLVMIRTransforms


@@ -48,6 +48,7 @@ enum Action {
   DUMP_FHE,
   DUMP_TFHE,
   DUMP_CONCRETE,
+  DUMP_CONCRETEWITHLOOPS,
   DUMP_BCONCRETE,
   DUMP_STD,
   DUMP_LLVM_DIALECT,
@@ -121,6 +122,9 @@ static llvm::cl::opt<enum Action> action(
"Lower to TFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete",
"Lower to Concrete and dump result")),
llvm::cl::values(clEnumValN(
Action::DUMP_CONCRETEWITHLOOPS, "dump-concrete-with-loops",
"Lower to Concrete, replace linalg ops with loops and dump result")),
llvm::cl::values(
clEnumValN(Action::DUMP_BCONCRETE, "dump-bconcrete",
"Lower to Bufferized Concrete and dump result")),
@@ -162,6 +166,13 @@ llvm::cl::opt<bool> loopParallelize(
"Generate (and execute if JIT) parallel loops from Linalg operations"),
llvm::cl::init(false));
llvm::cl::opt<bool> batchConcreteOps(
"batch-concrete-ops",
llvm::cl::desc(
"Hoist scalar Concrete operations with corresponding batched "
"operations out of loop nests as batched operations"),
llvm::cl::init(false));
llvm::cl::opt<bool> dataflowParallelize(
"parallelize-dataflow",
llvm::cl::desc(
@@ -288,6 +299,7 @@ cmdlineCompilationOptions() {
   options.autoParallelize = cmdline::autoParallelize;
   options.loopParallelize = cmdline::loopParallelize;
   options.dataflowParallelize = cmdline::dataflowParallelize;
+  options.batchConcreteOps = cmdline::batchConcreteOps;
   options.optimizeConcrete = cmdline::optimizeConcrete;
   options.emitGPUOps = cmdline::emitGPUOps;
@@ -457,6 +469,9 @@ mlir::LogicalResult processInputBuffer(
   case Action::DUMP_CONCRETE:
     target = mlir::concretelang::CompilerEngine::Target::CONCRETE;
     break;
+  case Action::DUMP_CONCRETEWITHLOOPS:
+    target = mlir::concretelang::CompilerEngine::Target::CONCRETEWITHLOOPS;
+    break;
   case Action::DUMP_BCONCRETE:
     target = mlir::concretelang::CompilerEngine::Target::BCONCRETE;
     break;