refactor(compiler): Separate TFHE parametrization into its own pipeline stage

This commit is contained in:
Andi Drebes
2023-03-23 10:54:16 +01:00
parent fdb4594a2b
commit 9cd238db82
5 changed files with 40 additions and 7 deletions

View File

@@ -208,6 +208,10 @@ public:
/// operations
TFHE,
/// Read sources and lower all FHE operations to TFHE
/// operations, then parametrize the TFHE operations
PARAMETRIZED_TFHE,
/// Read sources and lower all FHE and TFHE operations to Concrete
/// operations
CONCRETE,

View File

@@ -58,9 +58,13 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
parametrizeTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult

View File

@@ -404,9 +404,18 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
if (target == Target::TFHE)
return std::move(res);
if (mlir::concretelang::pipeline::parametrizeTFHE(mlirContext, module,
res.fheContext, enablePass)
.failed()) {
return errorDiag("Parametrization of TFHE operations failed");
}
if (target == Target::PARAMETRIZED_TFHE)
return std::move(res);
// TFHE -> Concrete
if (mlir::concretelang::pipeline::lowerTFHEToConcrete(
mlirContext, module, res.fheContext, this->enablePass)
if (mlir::concretelang::pipeline::lowerTFHEToConcrete(mlirContext, module,
this->enablePass)
.failed()) {
return errorDiag("Lowering from TFHE to Concrete failed");
}

View File

@@ -270,11 +270,11 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
}
mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass) {
parametrizeTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("TFHEToConcrete", pm, context);
pipelinePrinting("ParametrizeTFHE", pm, context);
if (fheContext.has_value()) {
addPotentiallyNestedPass(
@@ -284,6 +284,15 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
enablePass);
}
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("TFHEToConcrete", pm, context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertTFHEToConcretePass(), enablePass);

View File

@@ -49,6 +49,7 @@ enum Action {
DUMP_FHE,
DUMP_FHE_NO_LINALG,
DUMP_TFHE,
DUMP_PARAMETRIZED_TFHE,
DUMP_CONCRETE,
DUMP_SDFG,
DUMP_STD,
@@ -124,6 +125,9 @@ static llvm::cl::opt<enum Action> action(
"Lower FHELinalg to FHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_TFHE, "dump-tfhe",
"Lower to TFHE and dump result")),
llvm::cl::values(clEnumValN(
Action::DUMP_PARAMETRIZED_TFHE, "dump-parametrized-tfhe",
"Lower to TFHE, parametrize TFHE operations and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete",
"Lower to Concrete and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_SDFG, "dump-sdfg",
@@ -525,6 +529,9 @@ mlir::LogicalResult processInputBuffer(
case Action::DUMP_TFHE:
target = mlir::concretelang::CompilerEngine::Target::TFHE;
break;
case Action::DUMP_PARAMETRIZED_TFHE:
target = mlir::concretelang::CompilerEngine::Target::PARAMETRIZED_TFHE;
break;
case Action::DUMP_CONCRETE:
target = mlir::concretelang::CompilerEngine::Target::CONCRETE;
break;