mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat: add option to enable/disable optimization
This commit is contained in:
@@ -45,6 +45,7 @@ struct CompilationOptions {
|
||||
bool autoParallelize;
|
||||
bool loopParallelize;
|
||||
bool dataflowParallelize;
|
||||
bool optimizeConcrete;
|
||||
llvm::Optional<std::vector<int64_t>> fhelinalgTileSizes;
|
||||
|
||||
llvm::Optional<std::string> clientParametersFuncName;
|
||||
@@ -54,7 +55,8 @@ struct CompilationOptions {
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false),
|
||||
dataflowParallelize(false), clientParametersFuncName(llvm::None),
|
||||
dataflowParallelize(false), optimizeConcrete(true),
|
||||
clientParametersFuncName(llvm::None),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG){};
|
||||
|
||||
CompilationOptions(std::string funcname) : CompilationOptions() {
|
||||
|
||||
@@ -52,6 +52,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
[](CompilationOptions &options, bool b) {
|
||||
options.dataflowParallelize = b;
|
||||
})
|
||||
.def("set_optimize_concrete",
|
||||
[](CompilationOptions &options, bool b) {
|
||||
options.optimizeConcrete = b;
|
||||
})
|
||||
.def("set_p_error",
|
||||
[](CompilationOptions &options, double p_error) {
|
||||
options.optimizerConfig.p_error = p_error;
|
||||
|
||||
@@ -108,6 +108,19 @@ class CompilationOptions(WrapperCpp):
|
||||
raise TypeError("can't set the option to a non-boolean value")
|
||||
self.cpp().set_dataflow_parallelize(dataflow_parallelize)
|
||||
|
||||
def set_optimize_concrete(self, optimize: bool):
|
||||
"""Set flag to enable/disable optimization of concrete intermediate representation.
|
||||
|
||||
Args:
|
||||
optimize (bool): whether to turn it on or off
|
||||
|
||||
Raises:
|
||||
TypeError: if the value to set is not boolean
|
||||
"""
|
||||
if not isinstance(optimize, bool):
|
||||
raise TypeError("can't set the option to a non-boolean value")
|
||||
self.cpp().set_optimize_concrete(optimize)
|
||||
|
||||
def set_funcname(self, funcname: str):
|
||||
"""Set entrypoint function name.
|
||||
|
||||
|
||||
@@ -248,7 +248,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
}
|
||||
|
||||
// Optimizing Concrete
|
||||
if (mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
|
||||
if (this->compilerOptions.optimizeConcrete &&
|
||||
mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
|
||||
this->enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Optimizing Concrete failed");
|
||||
@@ -288,13 +289,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
}
|
||||
}
|
||||
|
||||
// Optimize Concrete
|
||||
if (mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
|
||||
this->enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Optimizing Concrete failed");
|
||||
}
|
||||
|
||||
// Concrete -> BConcrete
|
||||
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
|
||||
mlirContext, module, this->enablePass, loopParallelize)
|
||||
|
||||
@@ -91,6 +91,12 @@ llvm::cl::opt<std::string> output("o",
|
||||
llvm::cl::opt<bool> verbose("verbose", llvm::cl::desc("verbose logs"),
|
||||
llvm::cl::init<bool>(false));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
optimizeConcrete("optimize-concrete",
|
||||
llvm::cl::desc("enable/disable optimizations of concrete "
|
||||
"dialects. (Enabled by default)"),
|
||||
llvm::cl::init<bool>(true));
|
||||
|
||||
llvm::cl::list<std::string> passes(
|
||||
"passes",
|
||||
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
|
||||
@@ -220,6 +226,7 @@ mlir::concretelang::CompilationOptions cmdlineCompilationOptions() {
|
||||
options.autoParallelize = cmdline::autoParallelize;
|
||||
options.loopParallelize = cmdline::loopParallelize;
|
||||
options.dataflowParallelize = cmdline::dataflowParallelize;
|
||||
options.optimizeConcrete = cmdline::optimizeConcrete;
|
||||
|
||||
if (cmdline::assumeMaxEintPrecision.hasValue() &&
|
||||
cmdline::assumeMaxMANP.hasValue()) {
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
// RUN: concretecompiler --optimize-concrete=false --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = arith.constant 0 : i7
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%[[V0]]) : (i7) -> !Concrete.cleartext<7>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %[[V1]]) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%0 = arith.constant 0 : i7
|
||||
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %2: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
Reference in New Issue
Block a user