mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
enhance(compiler): Expose a compilation options instead of just the funcname
This commit is contained in:
@@ -49,7 +49,7 @@ jit_lambda_support(const char *runtimeLibPath);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile(JITLambdaSupport_C support, const char *module,
|
||||
const char *funcname);
|
||||
mlir::concretelang::CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
jit_load_client_parameters(JITLambdaSupport_C support,
|
||||
@@ -76,7 +76,7 @@ library_lambda_support(const char *outputPath);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibraryLambdaSupport_C support, const char *module,
|
||||
const char *funcname);
|
||||
mlir::concretelang::CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
library_load_client_parameters(LibraryLambdaSupport_C support,
|
||||
|
||||
@@ -36,6 +36,30 @@ protected:
|
||||
llvm::LLVMContext *llvmContext;
|
||||
};
|
||||
|
||||
/// Compilation options allows to configure the compilation pipeline.
|
||||
struct CompilationOptions {
|
||||
llvm::Optional<mlir::concretelang::V0FHEConstraint> v0FHEConstraints;
|
||||
|
||||
bool verifyDiagnostics;
|
||||
|
||||
bool autoParallelize;
|
||||
bool loopParallelize;
|
||||
bool dataflowParallelize;
|
||||
llvm::Optional<std::vector<int64_t>> fhelinalgTileSizes;
|
||||
|
||||
llvm::Optional<std::string> clientParametersFuncName;
|
||||
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false),
|
||||
dataflowParallelize(false), clientParametersFuncName(llvm::None){};
|
||||
|
||||
CompilationOptions(std::string funcname)
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false),
|
||||
dataflowParallelize(false), clientParametersFuncName(funcname){};
|
||||
};
|
||||
|
||||
class CompilerEngine {
|
||||
public:
|
||||
// Result of an invocation of the `CompilerEngine` with optional
|
||||
@@ -176,6 +200,27 @@ public:
|
||||
llvm::Expected<CompilerEngine::Library> compile(llvm::SourceMgr &sm,
|
||||
std::string libraryPath);
|
||||
|
||||
void setCompilationOptions(CompilationOptions &options) {
|
||||
if (options.v0FHEConstraints.hasValue()) {
|
||||
setFHEConstraints(*options.v0FHEConstraints);
|
||||
}
|
||||
|
||||
setVerifyDiagnostics(options.verifyDiagnostics);
|
||||
|
||||
setAutoParallelize(options.autoParallelize);
|
||||
setLoopParallelize(options.loopParallelize);
|
||||
setDataflowParallelize(options.dataflowParallelize);
|
||||
|
||||
if (options.clientParametersFuncName.hasValue()) {
|
||||
setGenerateClientParameters(true);
|
||||
setClientParametersFuncName(*options.clientParametersFuncName);
|
||||
}
|
||||
|
||||
if (options.fhelinalgTileSizes.hasValue()) {
|
||||
setFHELinalgTileSizes(*options.fhelinalgTileSizes);
|
||||
}
|
||||
}
|
||||
|
||||
void setFHEConstraints(const mlir::concretelang::V0FHEConstraint &c);
|
||||
void setMaxEintPrecision(size_t v);
|
||||
void setMaxMANP(size_t v);
|
||||
|
||||
@@ -38,7 +38,7 @@ public:
|
||||
mlir::makeOptimizingTransformer(3, 0, nullptr));
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, std::string funcname = "main") override;
|
||||
compile(llvm::SourceMgr &program, CompilationOptions options) override;
|
||||
using LambdaSupport::compile;
|
||||
|
||||
llvm::Expected<concretelang::JITLambda *>
|
||||
|
||||
@@ -241,25 +241,29 @@ public:
|
||||
};
|
||||
|
||||
template <typename Lambda, typename CompilationResult> class LambdaSupport {
|
||||
|
||||
public:
|
||||
typedef Lambda lambda;
|
||||
typedef CompilationResult compilationResult;
|
||||
|
||||
virtual ~LambdaSupport() {}
|
||||
|
||||
/// Compile the mlir program and produces a compilation result if succeed.
|
||||
llvm::Expected<std::unique_ptr<CompilationResult>> virtual compile(
|
||||
llvm::SourceMgr &program, std::string funcname = "main");
|
||||
llvm::SourceMgr &program,
|
||||
CompilationOptions options = CompilationOptions("main"));
|
||||
|
||||
llvm::Expected<std::unique_ptr<CompilationResult>>
|
||||
compile(llvm::StringRef program, std::string funcname = "main") {
|
||||
return compile(llvm::MemoryBuffer::getMemBuffer(program), funcname);
|
||||
compile(llvm::StringRef program,
|
||||
CompilationOptions options = CompilationOptions("main")) {
|
||||
return compile(llvm::MemoryBuffer::getMemBuffer(program), options);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<CompilationResult>>
|
||||
compile(std::unique_ptr<llvm::MemoryBuffer> program,
|
||||
std::string funcname = "main") {
|
||||
CompilationOptions options = CompilationOptions("main")) {
|
||||
llvm::SourceMgr sm;
|
||||
sm.AddNewSourceBuffer(std::move(program), llvm::SMLoc());
|
||||
return compile(sm, funcname);
|
||||
return compile(sm, options);
|
||||
}
|
||||
|
||||
/// Load the server lambda from the compilation result.
|
||||
@@ -312,6 +316,99 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <class LambdaSupport> class ClientServer {
|
||||
public:
|
||||
static llvm::Expected<ClientServer>
|
||||
create(llvm::StringRef program,
|
||||
CompilationOptions options = CompilationOptions("main"),
|
||||
llvm::Optional<clientlib::KeySetCache> cache = {},
|
||||
LambdaSupport support = LambdaSupport()) {
|
||||
auto compilationResult = support.compile(program, options);
|
||||
if (auto err = compilationResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
auto lambda = support.loadServerLambda(**compilationResult);
|
||||
if (auto err = lambda.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
auto clientParameters = support.loadClientParameters(**compilationResult);
|
||||
if (auto err = clientParameters.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
auto keySet = support.keySet(*clientParameters, cache);
|
||||
if (auto err = keySet.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
auto f = ClientServer();
|
||||
f.lambda = *lambda;
|
||||
f.compilationResult = std::move(*compilationResult);
|
||||
f.keySet = std::move(*keySet);
|
||||
f.clientParameters = *clientParameters;
|
||||
f.support = support;
|
||||
return std::move(f);
|
||||
}
|
||||
|
||||
template <typename ResT = uint64_t>
|
||||
llvm::Expected<ResT> operator()(llvm::ArrayRef<LambdaArgument *> args) {
|
||||
auto publicArguments = LambdaArgumentAdaptor::exportArguments(
|
||||
args, clientParameters, *this->keySet);
|
||||
|
||||
if (auto err = publicArguments.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
auto publicResult = support.serverCall(lambda, **publicArguments);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
return typedResult<ResT>(*keySet, **publicResult);
|
||||
}
|
||||
|
||||
template <typename T, typename ResT = uint64_t>
|
||||
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
|
||||
auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, args);
|
||||
if (encryptedArgs.has_error()) {
|
||||
return StreamStringError(encryptedArgs.error().mesg);
|
||||
}
|
||||
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
|
||||
clientParameters, keySet->runtimeContext());
|
||||
if (!publicArguments.has_value()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
auto publicResult = support.serverCall(lambda, *publicArguments.value());
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
return typedResult<ResT>(*keySet, **publicResult);
|
||||
}
|
||||
|
||||
template <typename ResT = uint64_t, typename... Args>
|
||||
llvm::Expected<ResT> operator()(const Args... args) {
|
||||
auto encryptedArgs =
|
||||
clientlib::EncryptedArguments::create(*keySet, args...);
|
||||
if (encryptedArgs.has_error()) {
|
||||
return StreamStringError(encryptedArgs.error().mesg);
|
||||
}
|
||||
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
|
||||
clientParameters, keySet->runtimeContext());
|
||||
if (publicArguments.has_error()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
auto publicResult = support.serverCall(lambda, *publicArguments.value());
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
return typedResult<ResT>(*keySet, **publicResult);
|
||||
}
|
||||
|
||||
private:
|
||||
typename LambdaSupport::lambda lambda;
|
||||
std::unique_ptr<typename LambdaSupport::compilationResult> compilationResult;
|
||||
std::unique_ptr<clientlib::KeySet> keySet;
|
||||
clientlib::ClientParameters clientParameters;
|
||||
LambdaSupport support;
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -36,11 +36,11 @@ public:
|
||||
LibraryLambdaSupport(std::string outputPath) : outputPath(outputPath) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, std::string funcname = "main") override {
|
||||
compile(llvm::SourceMgr &program, CompilationOptions options) override {
|
||||
// Setup the compiler engine
|
||||
auto context = CompilationContext::createShared();
|
||||
concretelang::CompilerEngine engine(context);
|
||||
engine.setClientParametersFuncName(funcname);
|
||||
engine.setCompilationOptions(options);
|
||||
|
||||
// Compile to a library
|
||||
auto library = engine.compile(program, outputPath);
|
||||
@@ -48,9 +48,13 @@ public:
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
if (!options.clientParametersFuncName.hasValue()) {
|
||||
return StreamStringError("Need to have a funcname to compile library");
|
||||
}
|
||||
|
||||
auto result = std::make_unique<LibraryCompilationResult>();
|
||||
result->libraryPath = outputPath;
|
||||
result->funcName = funcname;
|
||||
result->funcName = *options.clientParametersFuncName;
|
||||
return std::move(result);
|
||||
}
|
||||
using LambdaSupport::compile;
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
using mlir::concretelang::CompilationOptions;
|
||||
using mlir::concretelang::JitCompilerEngine;
|
||||
using mlir::concretelang::JitLambdaSupport;
|
||||
using mlir::concretelang::LambdaArgument;
|
||||
@@ -43,19 +44,25 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
|
||||
m.def("terminate_parallelization", &terminateParallelization);
|
||||
|
||||
pybind11::class_<JitCompilerEngine>(m, "JitCompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def_static("build_lambda",
|
||||
[](std::string mlir_input, std::string func_name,
|
||||
std::string runtime_lib_path, std::string keysetcache_path,
|
||||
bool auto_parallelize, bool loop_parallelize,
|
||||
bool df_parallelize) {
|
||||
return buildLambda(mlir_input.c_str(), func_name.c_str(),
|
||||
noEmptyStringPtr(runtime_lib_path),
|
||||
noEmptyStringPtr(keysetcache_path),
|
||||
auto_parallelize, loop_parallelize,
|
||||
df_parallelize);
|
||||
});
|
||||
pybind11::class_<CompilationOptions>(m, "CompilationOptions")
|
||||
.def(pybind11::init(
|
||||
[](std::string funcname) { return CompilationOptions(funcname); }))
|
||||
.def("set_funcname",
|
||||
[](CompilationOptions &options, std::string funcname) {
|
||||
options.clientParametersFuncName = funcname;
|
||||
})
|
||||
.def("set_verify_diagnostics",
|
||||
[](CompilationOptions &options, bool b) {
|
||||
options.verifyDiagnostics = b;
|
||||
})
|
||||
.def("auto_parallelize", [](CompilationOptions &options,
|
||||
bool b) { options.autoParallelize = b; })
|
||||
.def("loop_parallelize", [](CompilationOptions &options,
|
||||
bool b) { options.loopParallelize = b; })
|
||||
.def("dataflow_parallelize", [](CompilationOptions &options, bool b) {
|
||||
options.dataflowParallelize = b;
|
||||
});
|
||||
|
||||
pybind11::class_<mlir::concretelang::JitCompilationResult>(
|
||||
m, "JitCompilationResult");
|
||||
pybind11::class_<mlir::concretelang::JITLambda>(m, "JITLambda");
|
||||
@@ -65,9 +72,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
}))
|
||||
.def("compile",
|
||||
[](JITLambdaSupport_C &support, std::string mlir_program,
|
||||
std::string func_name) {
|
||||
return jit_compile(support, mlir_program.c_str(),
|
||||
func_name.c_str());
|
||||
CompilationOptions options) {
|
||||
return jit_compile(support, mlir_program.c_str(), options);
|
||||
})
|
||||
.def("load_client_parameters",
|
||||
[](JITLambdaSupport_C &support,
|
||||
@@ -102,9 +108,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
}))
|
||||
.def("compile",
|
||||
[](LibraryLambdaSupport_C &support, std::string mlir_program,
|
||||
std::string func_name) {
|
||||
return library_compile(support, mlir_program.c_str(),
|
||||
func_name.c_str());
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
return library_compile(support, mlir_program.c_str(), options);
|
||||
})
|
||||
.def("load_client_parameters",
|
||||
[](LibraryLambdaSupport_C &support,
|
||||
|
||||
@@ -22,6 +22,8 @@ from mlir._mlir_libs._concretelang._compiler import PublicResult
|
||||
from mlir._mlir_libs._concretelang._compiler import PublicArguments
|
||||
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import CompilationOptions
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport as _JITLambdaSupport
|
||||
from mlir._mlir_libs._concretelang._compiler import JitCompilationResult
|
||||
from mlir._mlir_libs._concretelang._compiler import JITLambda
|
||||
@@ -259,7 +261,7 @@ class JITCompilerSupport:
|
||||
runtime_lib_path = _lookup_runtime_lib()
|
||||
self._support = _JITLambdaSupport(runtime_lib_path)
|
||||
|
||||
def compile(self, mlir_program: str, func_name: str = "main") -> JitCompilationResult:
|
||||
def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> JitCompilationResult:
|
||||
"""JIT Compile a function define in the mlir_program to its homomorphic equivalent.
|
||||
|
||||
Args:
|
||||
@@ -271,7 +273,7 @@ class JITCompilerSupport:
|
||||
"""
|
||||
if not isinstance(mlir_program, str):
|
||||
raise TypeError("mlir_program must be an `str`")
|
||||
return self._support.compile(mlir_program, func_name)
|
||||
return self._support.compile(mlir_program, options)
|
||||
|
||||
def load_client_parameters(self, compilation_result: JitCompilationResult) -> ClientParameters:
|
||||
"""Load the client parameters from the JIT compilation result"""
|
||||
@@ -299,7 +301,7 @@ class LibraryCompilerSupport:
|
||||
self._library_path = outputPath
|
||||
self._support = _LibraryLambdaSupport(outputPath)
|
||||
|
||||
def compile(self, mlir_program: str, func_name: str = "main") -> LibraryCompilationResult:
|
||||
def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> LibraryCompilationResult:
|
||||
"""Compile a function define in the mlir_program to its homomorphic equivalent and save as library.
|
||||
|
||||
Args:
|
||||
@@ -311,9 +313,9 @@ class LibraryCompilerSupport:
|
||||
"""
|
||||
if not isinstance(mlir_program, str):
|
||||
raise TypeError("mlir_program must be an `str`")
|
||||
if not isinstance(func_name, str):
|
||||
if not isinstance(options, CompilationOptions):
|
||||
raise TypeError("mlir_program must be an `str`")
|
||||
return self._support.compile(mlir_program, func_name)
|
||||
return self._support.compile(mlir_program, options)
|
||||
|
||||
def reload(self, func_name: str = "main") -> LibraryCompilationResult:
|
||||
"""Reload the library compilation result from the outputPath.
|
||||
|
||||
@@ -32,10 +32,10 @@ jit_lambda_support(const char *runtimeLibPath) {
|
||||
|
||||
std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile(JITLambdaSupport_C support, const char *module,
|
||||
const char *funcname) {
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
mlir::concretelang::JitLambdaSupport esupport;
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
|
||||
esupport.compile(module, funcname));
|
||||
esupport.compile(module, options));
|
||||
return std::move(*compilationResult);
|
||||
}
|
||||
|
||||
@@ -73,9 +73,9 @@ library_lambda_support(const char *outputPath) {
|
||||
|
||||
std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibraryLambdaSupport_C support, const char *module,
|
||||
const char *funcname) {
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
|
||||
support.support.compile(module, funcname));
|
||||
support.support.compile(module, options));
|
||||
return std::move(*compilationResult);
|
||||
}
|
||||
|
||||
|
||||
@@ -16,15 +16,14 @@ JitLambdaSupport::JitLambdaSupport(
|
||||
: runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) {
|
||||
JitLambdaSupport::compile(llvm::SourceMgr &program,
|
||||
CompilationOptions options) {
|
||||
|
||||
// Setup the compiler engine
|
||||
auto context = std::make_shared<CompilationContext>();
|
||||
concretelang::CompilerEngine engine(context);
|
||||
|
||||
// We need client parameters to be generated
|
||||
engine.setGenerateClientParameters(true);
|
||||
engine.setClientParametersFuncName(funcname);
|
||||
engine.setCompilationOptions(options);
|
||||
|
||||
// Compile to LLVM Dialect
|
||||
auto compilationResult =
|
||||
@@ -34,10 +33,15 @@ JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
if (!options.clientParametersFuncName.hasValue()) {
|
||||
return StreamStringError("Need to have a funcname to JIT compile");
|
||||
}
|
||||
|
||||
// Compile from LLVM Dialect to JITLambda
|
||||
auto mlirModule = compilationResult.get().mlirModuleRef->get();
|
||||
auto lambda = concretelang::JITLambda::create(
|
||||
funcname, mlirModule, llvmOptPipeline, runtimeLibPath);
|
||||
*options.clientParametersFuncName, mlirModule, llvmOptPipeline,
|
||||
runtimeLibPath);
|
||||
if (auto err = lambda.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#include "EndToEndFixture.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "llvm/Support/YAMLParser.h"
|
||||
#include "llvm/Support/YAMLTraits.h"
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/JitLambdaSupport.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
|
||||
#include "globals.h"
|
||||
|
||||
Reference in New Issue
Block a user