mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor(compiler): Separate TFHE parametrization into its own pipeline stage
This commit is contained in:
@@ -208,6 +208,10 @@ public:
|
||||
/// operations
|
||||
TFHE,
|
||||
|
||||
/// Read sources and lower all FHE operations to TFHE
|
||||
/// operations, then parametrize the TFHE operations
|
||||
PARAMETRIZED_TFHE,
|
||||
|
||||
/// Read sources and lower all FHE and TFHE operations to Concrete
|
||||
/// operations
|
||||
CONCRETE,
|
||||
|
||||
@@ -58,9 +58,13 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
parametrizeTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
|
||||
@@ -404,9 +404,18 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
if (target == Target::TFHE)
|
||||
return std::move(res);
|
||||
|
||||
if (mlir::concretelang::pipeline::parametrizeTFHE(mlirContext, module,
|
||||
res.fheContext, enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Parametrization of TFHE operations failed");
|
||||
}
|
||||
|
||||
if (target == Target::PARAMETRIZED_TFHE)
|
||||
return std::move(res);
|
||||
|
||||
// TFHE -> Concrete
|
||||
if (mlir::concretelang::pipeline::lowerTFHEToConcrete(
|
||||
mlirContext, module, res.fheContext, this->enablePass)
|
||||
if (mlir::concretelang::pipeline::lowerTFHEToConcrete(mlirContext, module,
|
||||
this->enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Lowering from TFHE to Concrete failed");
|
||||
}
|
||||
|
||||
@@ -270,11 +270,11 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
parametrizeTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("TFHEToConcrete", pm, context);
|
||||
pipelinePrinting("ParametrizeTFHE", pm, context);
|
||||
|
||||
if (fheContext.has_value()) {
|
||||
addPotentiallyNestedPass(
|
||||
@@ -284,6 +284,15 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
enablePass);
|
||||
}
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("TFHEToConcrete", pm, context);
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertTFHEToConcretePass(), enablePass);
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ enum Action {
|
||||
DUMP_FHE,
|
||||
DUMP_FHE_NO_LINALG,
|
||||
DUMP_TFHE,
|
||||
DUMP_PARAMETRIZED_TFHE,
|
||||
DUMP_CONCRETE,
|
||||
DUMP_SDFG,
|
||||
DUMP_STD,
|
||||
@@ -124,6 +125,9 @@ static llvm::cl::opt<enum Action> action(
|
||||
"Lower FHELinalg to FHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_TFHE, "dump-tfhe",
|
||||
"Lower to TFHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(
|
||||
Action::DUMP_PARAMETRIZED_TFHE, "dump-parametrized-tfhe",
|
||||
"Lower to TFHE, parametrize TFHE operations and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete",
|
||||
"Lower to Concrete and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_SDFG, "dump-sdfg",
|
||||
@@ -525,6 +529,9 @@ mlir::LogicalResult processInputBuffer(
|
||||
case Action::DUMP_TFHE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::TFHE;
|
||||
break;
|
||||
case Action::DUMP_PARAMETRIZED_TFHE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::PARAMETRIZED_TFHE;
|
||||
break;
|
||||
case Action::DUMP_CONCRETE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::CONCRETE;
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user