mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(optimizer): expose the p_error parameter
simplify the handling of options by relying more on CompilationOptions
This commit is contained in:
@@ -49,15 +49,17 @@ struct CompilationOptions {
|
||||
|
||||
llvm::Optional<std::string> clientParametersFuncName;
|
||||
|
||||
optimizer::Config optimizerConfig;
|
||||
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false),
|
||||
dataflowParallelize(false), clientParametersFuncName(llvm::None){};
|
||||
dataflowParallelize(false), clientParametersFuncName(llvm::None),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG){};
|
||||
|
||||
CompilationOptions(std::string funcname)
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false),
|
||||
dataflowParallelize(false), clientParametersFuncName(funcname){};
|
||||
CompilationOptions(std::string funcname) : CompilationOptions() {
|
||||
clientParametersFuncName = funcname;
|
||||
}
|
||||
};
|
||||
|
||||
class CompilerEngine {
|
||||
@@ -180,10 +182,9 @@ public:
|
||||
};
|
||||
|
||||
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
|
||||
: overrideMaxEintPrecision(), overrideMaxMANP(),
|
||||
clientParametersFuncName(), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false),
|
||||
dataflowParallelize(false), generateClientParameters(false),
|
||||
: overrideMaxEintPrecision(), overrideMaxMANP(), compilerOptions(),
|
||||
generateClientParameters(
|
||||
compilerOptions.clientParametersFuncName.hasValue()),
|
||||
enablePass([](mlir::Pass *pass) { return true; }),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
@@ -210,48 +211,26 @@ public:
|
||||
std::string runtimeLibraryPath = "");
|
||||
|
||||
void setCompilationOptions(CompilationOptions &options) {
|
||||
compilerOptions = options;
|
||||
if (options.v0FHEConstraints.hasValue()) {
|
||||
setFHEConstraints(*options.v0FHEConstraints);
|
||||
}
|
||||
|
||||
setVerifyDiagnostics(options.verifyDiagnostics);
|
||||
|
||||
setAutoParallelize(options.autoParallelize);
|
||||
setLoopParallelize(options.loopParallelize);
|
||||
setDataflowParallelize(options.dataflowParallelize);
|
||||
|
||||
if (options.clientParametersFuncName.hasValue()) {
|
||||
setGenerateClientParameters(true);
|
||||
setClientParametersFuncName(*options.clientParametersFuncName);
|
||||
}
|
||||
|
||||
if (options.fhelinalgTileSizes.hasValue()) {
|
||||
setFHELinalgTileSizes(*options.fhelinalgTileSizes);
|
||||
}
|
||||
}
|
||||
|
||||
void setFHEConstraints(const mlir::concretelang::V0FHEConstraint &c);
|
||||
void setMaxEintPrecision(size_t v);
|
||||
void setMaxMANP(size_t v);
|
||||
void setVerifyDiagnostics(bool v);
|
||||
void setAutoParallelize(bool v);
|
||||
void setLoopParallelize(bool v);
|
||||
void setDataflowParallelize(bool v);
|
||||
void setGenerateClientParameters(bool v);
|
||||
void setClientParametersFuncName(const llvm::StringRef &name);
|
||||
void setFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
|
||||
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
protected:
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision;
|
||||
llvm::Optional<size_t> overrideMaxMANP;
|
||||
llvm::Optional<std::string> clientParametersFuncName;
|
||||
llvm::Optional<std::vector<int64_t>> fhelinalgTileSizes;
|
||||
|
||||
bool verifyDiagnostics;
|
||||
bool autoParallelize;
|
||||
bool loopParallelize;
|
||||
bool dataflowParallelize;
|
||||
CompilationOptions compilerOptions;
|
||||
bool generateClientParameters;
|
||||
std::function<bool(mlir::Pass *)> enablePass;
|
||||
|
||||
|
||||
@@ -13,7 +13,17 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
llvm::Optional<V0Parameter> getV0Parameter(V0FHEConstraint constraint);
|
||||
namespace optimizer {
|
||||
constexpr double P_ERROR_4_SIGMA = 1.0 - 0.999936657516;
|
||||
struct Config {
|
||||
double p_error;
|
||||
bool display;
|
||||
};
|
||||
constexpr Config DEFAULT_CONFIG = {P_ERROR_4_SIGMA, false};
|
||||
} // namespace optimizer
|
||||
|
||||
llvm::Optional<V0Parameter> getV0Parameter(V0FHEConstraint constraint,
|
||||
optimizer::Config optimizerConfig);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -48,9 +48,19 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
bool b) { options.autoParallelize = b; })
|
||||
.def("set_loop_parallelize", [](CompilationOptions &options,
|
||||
bool b) { options.loopParallelize = b; })
|
||||
.def("set_dataflow_parallelize", [](CompilationOptions &options, bool b) {
|
||||
options.dataflowParallelize = b;
|
||||
});
|
||||
.def("set_dataflow_parallelize",
|
||||
[](CompilationOptions &options, bool b) {
|
||||
options.dataflowParallelize = b;
|
||||
})
|
||||
.def("set_p_error",
|
||||
[](CompilationOptions &options, double p_error) {
|
||||
options.optimizerConfig.p_error = p_error;
|
||||
})
|
||||
.def("set_display_optimizer_choice",
|
||||
[](CompilationOptions &options, bool display) {
|
||||
options.optimizerConfig.display = display;
|
||||
});
|
||||
;
|
||||
|
||||
pybind11::class_<mlir::concretelang::JitCompilationResult>(
|
||||
m, "JITCompilationResult");
|
||||
|
||||
@@ -120,3 +120,34 @@ class CompilationOptions(WrapperCpp):
|
||||
if not isinstance(funcname, str):
|
||||
raise TypeError("can't set the option to a non-str value")
|
||||
self.cpp().set_funcname(funcname)
|
||||
|
||||
def set_p_error(self, p_error: float):
|
||||
"""Set global error probability for each pbs.
|
||||
|
||||
Args:
|
||||
p_error (float): probability of error for each lut
|
||||
|
||||
Raises:
|
||||
TypeError: if the value to set is not float
|
||||
ValueError: if the value to set is not in interval ]0; 1[
|
||||
"""
|
||||
if not isinstance(p_error, float):
|
||||
raise TypeError("can't set p_error to a non-float value")
|
||||
if p_error in (0.0, 1.0):
|
||||
raise ValueError("p_error cannot be 0 or 1")
|
||||
if not 0.0 <= p_error <= 1.0:
|
||||
raise ValueError("p_error should be a probability in ]0; 1[")
|
||||
self.cpp().set_p_error(p_error)
|
||||
|
||||
def set_display_optimizer_choice(self, display: bool):
|
||||
"""Set display flag of optimizer choices.
|
||||
|
||||
Args:
|
||||
display (bool): if true the compiler display optimizer choices
|
||||
|
||||
Raises:
|
||||
TypeError: if the value is not a bool
|
||||
"""
|
||||
if not isinstance(display, bool):
|
||||
raise TypeError("display should be a bool")
|
||||
self.cpp().set_display_optimizer_choice(display)
|
||||
|
||||
@@ -96,18 +96,6 @@ void CompilerEngine::setFHEConstraints(
|
||||
this->overrideMaxMANP = c.norm2;
|
||||
}
|
||||
|
||||
void CompilerEngine::setVerifyDiagnostics(bool v) {
|
||||
this->verifyDiagnostics = v;
|
||||
}
|
||||
|
||||
void CompilerEngine::setAutoParallelize(bool v) { this->autoParallelize = v; }
|
||||
|
||||
void CompilerEngine::setLoopParallelize(bool v) { this->loopParallelize = v; }
|
||||
|
||||
void CompilerEngine::setDataflowParallelize(bool v) {
|
||||
this->dataflowParallelize = v;
|
||||
}
|
||||
|
||||
void CompilerEngine::setGenerateClientParameters(bool v) {
|
||||
this->generateClientParameters = v;
|
||||
}
|
||||
@@ -118,14 +106,6 @@ void CompilerEngine::setMaxEintPrecision(size_t v) {
|
||||
|
||||
void CompilerEngine::setMaxMANP(size_t v) { this->overrideMaxMANP = v; }
|
||||
|
||||
void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) {
|
||||
this->clientParametersFuncName = name.str();
|
||||
}
|
||||
|
||||
void CompilerEngine::setFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes) {
|
||||
this->fhelinalgTileSizes = sizes.vec();
|
||||
}
|
||||
|
||||
void CompilerEngine::setEnablePass(
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
this->enablePass = enablePass;
|
||||
@@ -163,7 +143,8 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
|
||||
if (!fheConstraintOrErr.get().hasValue()) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
auto fheParams = getV0Parameter(fheConstraintOrErr.get().getValue());
|
||||
auto fheParams = getV0Parameter(fheConstraintOrErr.get().getValue(),
|
||||
this->compilerOptions.optimizerConfig);
|
||||
|
||||
if (!fheParams) {
|
||||
return StreamStringError()
|
||||
@@ -202,7 +183,10 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
mlir::OwningModuleRef mlirModuleRef =
|
||||
mlir::parseSourceFile<mlir::ModuleOp>(sm, &mlirContext);
|
||||
|
||||
if (this->verifyDiagnostics) {
|
||||
CompilationOptions &options = this->compilerOptions;
|
||||
bool parallelizeLoops =
|
||||
options.autoParallelize || options.dataflowParallelize;
|
||||
if (options.verifyDiagnostics) {
|
||||
if (smHandler.verify().failed())
|
||||
return StreamStringError("Verification of diagnostics failed");
|
||||
else
|
||||
@@ -224,9 +208,9 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return std::move(err);
|
||||
|
||||
// FHELinalg tiling
|
||||
if (this->fhelinalgTileSizes) {
|
||||
if (options.fhelinalgTileSizes) {
|
||||
if (mlir::concretelang::pipeline::markFHELinalgForTiling(
|
||||
mlirContext, module, *this->fhelinalgTileSizes, enablePass)
|
||||
mlirContext, module, *options.fhelinalgTileSizes, enablePass)
|
||||
.failed())
|
||||
return errorDiag("Marking of FHELinalg operations for tiling failed");
|
||||
}
|
||||
@@ -238,7 +222,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
}
|
||||
|
||||
// Dataflow parallelization
|
||||
if ((this->autoParallelize || this->dataflowParallelize) &&
|
||||
if (parallelizeLoops &&
|
||||
mlir::concretelang::pipeline::autopar(mlirContext, module, enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Dataflow parallelization failed");
|
||||
@@ -267,7 +251,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
|
||||
// Generate client parameters if requested
|
||||
if (this->generateClientParameters) {
|
||||
if (!this->clientParametersFuncName.hasValue()) {
|
||||
if (!options.clientParametersFuncName.hasValue()) {
|
||||
return StreamStringError(
|
||||
"Generation of client parameters requested, but no function name "
|
||||
"specified");
|
||||
@@ -278,7 +262,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
}
|
||||
}
|
||||
// Generate client parameters if requested
|
||||
auto funcName = this->clientParametersFuncName.getValueOr("main");
|
||||
auto funcName = options.clientParametersFuncName.getValueOr("main");
|
||||
if (this->generateClientParameters || target == Target::LIBRARY) {
|
||||
if (!res.fheContext.hasValue()) {
|
||||
// Some tests involve call a to non encrypted functions
|
||||
@@ -298,8 +282,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
|
||||
// Concrete -> BConcrete
|
||||
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
|
||||
mlirContext, module, this->enablePass,
|
||||
this->loopParallelize || this->autoParallelize)
|
||||
mlirContext, module, this->enablePass, parallelizeLoops)
|
||||
.failed()) {
|
||||
return StreamStringError(
|
||||
"Lowering from Concrete to Bufferized Concrete failed");
|
||||
@@ -321,8 +304,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
|
||||
// MLIR canonical dialects -> LLVM Dialect
|
||||
if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
|
||||
mlirContext, module, enablePass,
|
||||
this->loopParallelize || this->autoParallelize)
|
||||
mlirContext, module, enablePass, parallelizeLoops)
|
||||
.failed()) {
|
||||
return errorDiag("Failed to lower to LLVM dialect");
|
||||
}
|
||||
|
||||
@@ -9,32 +9,76 @@
|
||||
/// from the optimizer output.
|
||||
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "concrete-optimizer.h"
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
const double P_ERROR_4_SIGMA = 1.0 - 0.999936657516;
|
||||
static void display(V0FHEConstraint constraint,
|
||||
optimizer::Config optimizerConfig,
|
||||
concrete_optimizer::Solution sol,
|
||||
std::chrono::milliseconds duration) {
|
||||
if (!optimizerConfig.display) {
|
||||
return;
|
||||
}
|
||||
auto o = llvm::outs;
|
||||
o() << "--- Circuit\n"
|
||||
<< " " << constraint.p << " bits integers\n"
|
||||
<< constraint.norm2 << " manp (maxi log2 norm2)\n"
|
||||
|
||||
llvm::Optional<V0Parameter> getV0Parameter(V0FHEConstraint constraint) {
|
||||
<< "--- Optimizer config\n"
|
||||
<< " " << optimizerConfig.p_error << " error per pbs call\n"
|
||||
|
||||
<< "--- For each Pbs call\n"
|
||||
<< " " << (long)sol.complexity / (1000 * 1000)
|
||||
<< " Millions Operations\n"
|
||||
<< " 1/" << int(1.0 / sol.p_error) << " errors (" << sol.p_error << ")\n"
|
||||
|
||||
<< "--- Parameters resolution\n"
|
||||
<< " 2**" << (size_t)std::log2l(sol.glwe_polynomial_size)
|
||||
<< " polynomial (" << sol.glwe_polynomial_size << ")\n"
|
||||
<< " " << sol.internal_ks_output_lwe_dimension << " lwe dimension \n"
|
||||
<< " keyswitch l,b=" << sol.ks_decomposition_level_count << ","
|
||||
<< sol.ks_decomposition_base_log << "\n"
|
||||
<< " blindrota l,b=" << sol.br_decomposition_level_count << ","
|
||||
<< sol.br_decomposition_base_log << "\n"
|
||||
<< " " << sol.noise_max << " variance max\n"
|
||||
<< " " << duration.count() << "ms to solve\n"
|
||||
<< "---\n";
|
||||
}
|
||||
|
||||
llvm::Optional<V0Parameter> getV0Parameter(V0FHEConstraint constraint,
|
||||
optimizer::Config optimizerConfig) {
|
||||
namespace chrono = std::chrono;
|
||||
int security = 128;
|
||||
// the norm2 0 is equivalent to a maximum noise_factor of 2.0
|
||||
// norm2 = 0 ==> 1.0 =< noise_factor < 2.0
|
||||
// norm2 = k ==> 2^norm2 =< noise_factor < 2.0^norm2 + 1
|
||||
double noise_factor = std::exp2(constraint.norm2 + 1);
|
||||
// https://github.com/zama-ai/concrete-optimizer/blob/prototype/python/optimizer/V0Parameters/tabulation.py#L58
|
||||
double p_error = P_ERROR_4_SIGMA;
|
||||
double p_error = optimizerConfig.p_error;
|
||||
auto start = chrono::high_resolution_clock::now();
|
||||
auto sol = concrete_optimizer::optimise_bootstrap(constraint.p, security,
|
||||
noise_factor, p_error);
|
||||
|
||||
auto stop = chrono::high_resolution_clock::now();
|
||||
if (sol.p_error == 1.0) {
|
||||
// The optimizer return a p_error = 1 if there is no solution
|
||||
return llvm::None;
|
||||
}
|
||||
auto duration = chrono::duration_cast<chrono::milliseconds>(stop - start);
|
||||
auto duration_s = chrono::duration_cast<chrono::seconds>(duration);
|
||||
if (duration_s.count() > 3) {
|
||||
llvm::errs() << "concrete-optimizer time: " << duration_s.count() << "s\n";
|
||||
}
|
||||
|
||||
display(constraint, optimizerConfig, sol, duration);
|
||||
|
||||
return mlir::concretelang::V0Parameter{
|
||||
sol.glwe_dimension,
|
||||
|
||||
@@ -36,6 +36,7 @@
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
#include "concretelang/Support/LLVMEmitFile.h"
|
||||
#include "concretelang/Support/Pipeline.h"
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
#include "concretelang/Support/logging.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
@@ -178,6 +179,16 @@ llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser> assumeMaxMANP(
|
||||
llvm::cl::desc(
|
||||
"Assume a maximum for the Minimum Arithmetic Noise Padding"));
|
||||
|
||||
llvm::cl::opt<double> pbsErrorProbability(
|
||||
"pbs-error-probability",
|
||||
llvm::cl::desc("Change the default probability of error for all pbs"),
|
||||
llvm::cl::init(mlir::concretelang::optimizer::DEFAULT_CONFIG.p_error));
|
||||
|
||||
llvm::cl::opt<bool> displayOptimizerChoice(
|
||||
"display-optimizer-choice",
|
||||
llvm::cl::desc("Display the information returned by the optimizer"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::list<int64_t> fhelinalgTileSizes(
|
||||
"fhelinalg-tile-sizes",
|
||||
llvm::cl::desc(
|
||||
@@ -185,35 +196,6 @@ llvm::cl::list<int64_t> fhelinalgTileSizes(
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
|
||||
} // namespace cmdline
|
||||
|
||||
llvm::Expected<mlir::concretelang::V0FHEContext> buildFHEContext(
|
||||
llvm::Optional<mlir::concretelang::V0FHEConstraint> autoFHEConstraints,
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP) {
|
||||
if (!autoFHEConstraints.hasValue() &&
|
||||
(!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) {
|
||||
return mlir::concretelang::StreamStringError(
|
||||
"Maximum encrypted integer precision and maximum for the Minimal"
|
||||
"Arithmetic Noise Passing are required, but were neither specified"
|
||||
"explicitly nor determined automatically");
|
||||
}
|
||||
|
||||
mlir::concretelang::V0FHEConstraint fheConstraints{
|
||||
overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue()
|
||||
: autoFHEConstraints.getValue().norm2,
|
||||
overrideMaxEintPrecision.hasValue() ? overrideMaxEintPrecision.getValue()
|
||||
: autoFHEConstraints.getValue().p};
|
||||
|
||||
auto parameter = getV0Parameter(fheConstraints);
|
||||
|
||||
if (!parameter) {
|
||||
return mlir::concretelang::StreamStringError()
|
||||
<< "Could not determine V0 parameters for 2-norm of "
|
||||
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
|
||||
}
|
||||
|
||||
return mlir::concretelang::V0FHEContext{fheConstraints, parameter.getValue()};
|
||||
}
|
||||
|
||||
namespace llvm {
|
||||
// This needs to be wrapped into the llvm namespace for proper
|
||||
// operator lookup
|
||||
@@ -231,6 +213,36 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
||||
}
|
||||
} // namespace llvm
|
||||
|
||||
mlir::concretelang::CompilationOptions cmdlineCompilationOptions() {
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
|
||||
options.verifyDiagnostics = cmdline::verifyDiagnostics;
|
||||
options.autoParallelize = cmdline::autoParallelize;
|
||||
options.loopParallelize = cmdline::loopParallelize;
|
||||
options.dataflowParallelize = cmdline::dataflowParallelize;
|
||||
|
||||
if (cmdline::assumeMaxEintPrecision.hasValue() &&
|
||||
cmdline::assumeMaxMANP.hasValue()) {
|
||||
options.v0FHEConstraints = mlir::concretelang::V0FHEConstraint{
|
||||
cmdline::assumeMaxMANP.getValue().getValue(),
|
||||
cmdline::assumeMaxEintPrecision.getValue().getValue(),
|
||||
};
|
||||
}
|
||||
|
||||
if (!cmdline::funcName.empty()) {
|
||||
options.clientParametersFuncName = cmdline::funcName;
|
||||
}
|
||||
|
||||
// Convert tile sizes to `Optional`
|
||||
if (!cmdline::fhelinalgTileSizes.empty())
|
||||
options.fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes);
|
||||
|
||||
options.optimizerConfig.p_error = cmdline::pbsErrorProbability;
|
||||
options.optimizerConfig.display = cmdline::displayOptimizerChoice;
|
||||
|
||||
return options;
|
||||
}
|
||||
|
||||
// Process a single source buffer
|
||||
//
|
||||
// The parameter `action` specifies how the buffer should be processed
|
||||
@@ -259,36 +271,14 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
||||
// Compilation output is written to the stream specified by `os`.
|
||||
mlir::LogicalResult processInputBuffer(
|
||||
std::unique_ptr<llvm::MemoryBuffer> buffer, std::string sourceFileName,
|
||||
enum Action action, const std::string &funcName,
|
||||
mlir::concretelang::CompilationOptions &options, enum Action action,
|
||||
llvm::ArrayRef<uint64_t> jitArgs,
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
|
||||
llvm::Optional<llvm::ArrayRef<int64_t>> fhelinalgTileSizes,
|
||||
bool autoParallelize, bool loopParallelize, bool dataflowParallelize,
|
||||
llvm::Optional<clientlib::KeySetCache> keySetCache, llvm::raw_ostream &os,
|
||||
std::shared_ptr<mlir::concretelang::CompilerEngine::Library> outputLib) {
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
|
||||
options.verifyDiagnostics = verifyDiagnostics;
|
||||
options.autoParallelize = autoParallelize;
|
||||
options.loopParallelize = loopParallelize;
|
||||
options.dataflowParallelize = dataflowParallelize;
|
||||
|
||||
if (overrideMaxEintPrecision.hasValue() && overrideMaxMANP.hasValue())
|
||||
options.v0FHEConstraints = {
|
||||
overrideMaxMANP.hasValue(),
|
||||
overrideMaxEintPrecision.hasValue(),
|
||||
};
|
||||
|
||||
if (!funcName.empty())
|
||||
options.clientParametersFuncName = funcName;
|
||||
|
||||
if (fhelinalgTileSizes.hasValue())
|
||||
options.fhelinalgTileSizes = *fhelinalgTileSizes;
|
||||
|
||||
std::string funcName = options.clientParametersFuncName.getValueOr("");
|
||||
if (action == Action::JIT_INVOKE) {
|
||||
auto lambdaOrErr =
|
||||
mlir::concretelang::ClientServer<mlir::concretelang::JITSupport>::
|
||||
@@ -373,7 +363,7 @@ mlir::LogicalResult processInputBuffer(
|
||||
retOrErr->llvmModule->setModuleIdentifier(sourceFileName);
|
||||
}
|
||||
|
||||
if (verifyDiagnostics) {
|
||||
if (options.verifyDiagnostics) {
|
||||
return mlir::success();
|
||||
} else if (action == Action::DUMP_LLVM_IR ||
|
||||
action == Action::DUMP_OPTIMIZED_LLVM_IR) {
|
||||
@@ -412,11 +402,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tile sizes to `Optional`
|
||||
llvm::Optional<llvm::ArrayRef<int64_t>> fhelinalgTileSizes;
|
||||
|
||||
if (!cmdline::fhelinalgTileSizes.empty())
|
||||
fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes);
|
||||
auto compilerOptions = cmdlineCompilationOptions();
|
||||
|
||||
llvm::Optional<clientlib::KeySetCache> jitKeySetCache;
|
||||
if (!cmdline::jitKeySetCachePath.empty()) {
|
||||
@@ -454,12 +440,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
auto process = [&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
|
||||
llvm::raw_ostream &os) {
|
||||
return processInputBuffer(
|
||||
std::move(inputBuffer), fileName, cmdline::action, cmdline::funcName,
|
||||
cmdline::jitArgs, cmdline::assumeMaxEintPrecision,
|
||||
cmdline::assumeMaxMANP, cmdline::verifyDiagnostics,
|
||||
fhelinalgTileSizes, cmdline::autoParallelize,
|
||||
cmdline::loopParallelize, cmdline::dataflowParallelize,
|
||||
jitKeySetCache, os, outputLib);
|
||||
std::move(inputBuffer), fileName, compilerOptions, cmdline::action,
|
||||
cmdline::jitArgs, jitKeySetCache, os, outputLib);
|
||||
};
|
||||
auto &os = output->os();
|
||||
auto res = mlir::failure();
|
||||
|
||||
@@ -36,7 +36,8 @@ compile(std::string outputLib, std::string source,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
mlir::concretelang::CompilerEngine ce{ccx};
|
||||
ce.setClientParametersFuncName(funcname);
|
||||
mlir::concretelang::CompilationOptions options(funcname);
|
||||
ce.setCompilationOptions(options);
|
||||
auto result = ce.compile(sources, outputLib);
|
||||
if (!result) {
|
||||
llvm::errs() << result.takeError();
|
||||
|
||||
@@ -182,6 +182,23 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result, keyset_ca
|
||||
assert_result(result, expected_result)
|
||||
|
||||
|
||||
def test_lib_compile_and_run_p_error(keyset_cache):
|
||||
mlir_input = """
|
||||
func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
"""
|
||||
args = (73,)
|
||||
expected_result = 73
|
||||
engine = LibrarySupport.new("./py_test_lib_compile_and_run_custom_perror")
|
||||
options = CompilationOptions.new("main")
|
||||
options.set_p_error(0.00001)
|
||||
options.set_display_optimizer_choice(True)
|
||||
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options)
|
||||
|
||||
|
||||
@pytest.mark.parallel
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result", end_to_end_parallel_fixture
|
||||
|
||||
Reference in New Issue
Block a user