enhance(compiler): Expose compilation options instead of just the funcname

This commit is contained in:
Quentin Bourgerie
2022-03-17 14:34:12 +01:00
parent 1b984f5119
commit 5b83b700d2
11 changed files with 203 additions and 47 deletions

View File

@@ -49,7 +49,7 @@ jit_lambda_support(const char *runtimeLibPath);
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile(JITLambdaSupport_C support, const char *module,
-const char *funcname);
+mlir::concretelang::CompilationOptions options);
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
jit_load_client_parameters(JITLambdaSupport_C support,
@@ -76,7 +76,7 @@ library_lambda_support(const char *outputPath);
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile(LibraryLambdaSupport_C support, const char *module,
-const char *funcname);
+mlir::concretelang::CompilationOptions options);
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
library_load_client_parameters(LibraryLambdaSupport_C support,

View File

@@ -36,6 +36,30 @@ protected:
llvm::LLVMContext *llvmContext;
};
/// Compilation options allows to configure the compilation pipeline.
/// Compilation options allows to configure the compilation pipeline.
struct CompilationOptions {
  /// Optional manual FHE constraints; left unset (llvm::None) when the
  /// pipeline should derive them itself.
  llvm::Optional<mlir::concretelang::V0FHEConstraint> v0FHEConstraints;

  /// Run the MLIR diagnostic verifier during compilation.
  bool verifyDiagnostics = false;

  /// Parallelization knobs forwarded to the engine.
  bool autoParallelize = false;
  bool loopParallelize = false;
  bool dataflowParallelize = false;

  /// Optional tile sizes forwarded to the FHELinalg tiling pass.
  llvm::Optional<std::vector<int64_t>> fhelinalgTileSizes;

  /// Name of the function for which client parameters are generated;
  /// unset means no client parameters are emitted.
  llvm::Optional<std::string> clientParametersFuncName;

  // In-class member initializers make the defaulted constructor
  // equivalent to the previous hand-written mem-init list.
  CompilationOptions() = default;

  /// Convenience constructor: everything defaulted except the client
  /// parameters function name. Intentionally non-explicit so a string
  /// still converts where options are expected.
  CompilationOptions(std::string funcname)
      : clientParametersFuncName(std::move(funcname)) {}
};
class CompilerEngine {
public:
// Result of an invocation of the `CompilerEngine` with optional
@@ -176,6 +200,27 @@ public:
llvm::Expected<CompilerEngine::Library> compile(llvm::SourceMgr &sm,
std::string libraryPath);
/// Applies every field of `options` to this engine through the
/// corresponding individual setter. Optional fields are only applied
/// when the caller actually provided a value.
void setCompilationOptions(CompilationOptions &options) {
  // Explicit FHE constraints override whatever the engine would infer.
  if (options.v0FHEConstraints) {
    setFHEConstraints(options.v0FHEConstraints.getValue());
  }
  setVerifyDiagnostics(options.verifyDiagnostics);
  setAutoParallelize(options.autoParallelize);
  setLoopParallelize(options.loopParallelize);
  setDataflowParallelize(options.dataflowParallelize);
  // Providing a function name implies that client parameters must be
  // generated at all.
  if (options.clientParametersFuncName) {
    setGenerateClientParameters(true);
    setClientParametersFuncName(options.clientParametersFuncName.getValue());
  }
  if (options.fhelinalgTileSizes) {
    setFHELinalgTileSizes(options.fhelinalgTileSizes.getValue());
  }
}
void setFHEConstraints(const mlir::concretelang::V0FHEConstraint &c);
void setMaxEintPrecision(size_t v);
void setMaxMANP(size_t v);

View File

@@ -38,7 +38,7 @@ public:
mlir::makeOptimizingTransformer(3, 0, nullptr));
llvm::Expected<std::unique_ptr<JitCompilationResult>>
-compile(llvm::SourceMgr &program, std::string funcname = "main") override;
+compile(llvm::SourceMgr &program, CompilationOptions options) override;
using LambdaSupport::compile;
llvm::Expected<concretelang::JITLambda *>

View File

@@ -241,25 +241,29 @@ public:
};
template <typename Lambda, typename CompilationResult> class LambdaSupport {
public:
typedef Lambda lambda;
typedef CompilationResult compilationResult;
virtual ~LambdaSupport() {}
/// Compile the mlir program and produces a compilation result if succeed.
llvm::Expected<std::unique_ptr<CompilationResult>> virtual compile(
-llvm::SourceMgr &program, std::string funcname = "main");
+llvm::SourceMgr &program,
+CompilationOptions options = CompilationOptions("main"));
llvm::Expected<std::unique_ptr<CompilationResult>>
-compile(llvm::StringRef program, std::string funcname = "main") {
-return compile(llvm::MemoryBuffer::getMemBuffer(program), funcname);
+compile(llvm::StringRef program,
+CompilationOptions options = CompilationOptions("main")) {
+return compile(llvm::MemoryBuffer::getMemBuffer(program), options);
}
llvm::Expected<std::unique_ptr<CompilationResult>>
compile(std::unique_ptr<llvm::MemoryBuffer> program,
-std::string funcname = "main") {
+CompilationOptions options = CompilationOptions("main")) {
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(std::move(program), llvm::SMLoc());
-return compile(sm, funcname);
+return compile(sm, options);
}
/// Load the server lambda from the compilation result.
@@ -312,6 +316,99 @@ public:
}
};
/// In-process harness playing both the client and the server role:
/// compiles a program, loads the server lambda, builds the client key
/// set, and lets callers invoke the lambda like a plain function.
template <class LambdaSupport> class ClientServer {
public:
  /// Compiles `program` with `options`, loads the server lambda and the
  /// client parameters from the compilation result, then generates (or
  /// fetches from `cache`, when provided) the matching key set.
  /// Each failing stage short-circuits and returns its error.
  static llvm::Expected<ClientServer>
  create(llvm::StringRef program,
         CompilationOptions options = CompilationOptions("main"),
         llvm::Optional<clientlib::KeySetCache> cache = {},
         LambdaSupport support = LambdaSupport()) {
    auto compilationResult = support.compile(program, options);
    if (auto err = compilationResult.takeError()) {
      return std::move(err);
    }
    auto lambda = support.loadServerLambda(**compilationResult);
    if (auto err = lambda.takeError()) {
      return std::move(err);
    }
    auto clientParameters = support.loadClientParameters(**compilationResult);
    if (auto err = clientParameters.takeError()) {
      return std::move(err);
    }
    auto keySet = support.keySet(*clientParameters, cache);
    if (auto err = keySet.takeError()) {
      return std::move(err);
    }
    auto f = ClientServer();
    f.lambda = *lambda;
    f.compilationResult = std::move(*compilationResult);
    f.keySet = std::move(*keySet);
    f.clientParameters = *clientParameters;
    f.support = support;
    // std::move is required: ClientServer holds unique_ptr members and
    // is therefore move-only.
    return std::move(f);
  }

  /// Calls the server lambda on pre-built LambdaArgument values:
  /// exports them to public (server-side) arguments using the key set,
  /// performs the server call, and converts the public result to ResT.
  template <typename ResT = uint64_t>
  llvm::Expected<ResT> operator()(llvm::ArrayRef<LambdaArgument *> args) {
    auto publicArguments = LambdaArgumentAdaptor::exportArguments(
        args, clientParameters, *this->keySet);
    if (auto err = publicArguments.takeError()) {
      return std::move(err);
    }
    auto publicResult = support.serverCall(lambda, **publicArguments);
    if (auto err = publicResult.takeError()) {
      return std::move(err);
    }
    return typedResult<ResT>(*keySet, **publicResult);
  }

  /// Same as above, but encrypts a homogeneous array of plain values
  /// with the key set before calling the server.
  template <typename T, typename ResT = uint64_t>
  llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
    auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, args);
    if (encryptedArgs.has_error()) {
      return StreamStringError(encryptedArgs.error().mesg);
    }
    auto publicArguments = encryptedArgs.value()->exportPublicArguments(
        clientParameters, keySet->runtimeContext());
    // NOTE(review): this overload checks !has_value() while the
    // variadic overload below checks has_error() — presumably
    // equivalent on this outcome type, but worth confirming.
    if (!publicArguments.has_value()) {
      return StreamStringError(publicArguments.error().mesg);
    }
    auto publicResult = support.serverCall(lambda, *publicArguments.value());
    if (auto err = publicResult.takeError()) {
      return std::move(err);
    }
    return typedResult<ResT>(*keySet, **publicResult);
  }

  /// Variadic convenience overload: encrypts each scalar argument, then
  /// follows the same export / server-call / decode path.
  template <typename ResT = uint64_t, typename... Args>
  llvm::Expected<ResT> operator()(const Args... args) {
    auto encryptedArgs =
        clientlib::EncryptedArguments::create(*keySet, args...);
    if (encryptedArgs.has_error()) {
      return StreamStringError(encryptedArgs.error().mesg);
    }
    auto publicArguments = encryptedArgs.value()->exportPublicArguments(
        clientParameters, keySet->runtimeContext());
    if (publicArguments.has_error()) {
      return StreamStringError(publicArguments.error().mesg);
    }
    auto publicResult = support.serverCall(lambda, *publicArguments.value());
    if (auto err = publicResult.takeError()) {
      return std::move(err);
    }
    return typedResult<ResT>(*keySet, **publicResult);
  }

private:
  // Server-side artifacts produced by create().
  typename LambdaSupport::lambda lambda;
  std::unique_ptr<typename LambdaSupport::compilationResult> compilationResult;
  // Client-side key material and parameters.
  std::unique_ptr<clientlib::KeySet> keySet;
  clientlib::ClientParameters clientParameters;
  LambdaSupport support;
};
} // namespace concretelang
} // namespace mlir

View File

@@ -36,11 +36,11 @@ public:
LibraryLambdaSupport(std::string outputPath) : outputPath(outputPath) {}
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
-compile(llvm::SourceMgr &program, std::string funcname = "main") override {
+compile(llvm::SourceMgr &program, CompilationOptions options) override {
// Setup the compiler engine
auto context = CompilationContext::createShared();
concretelang::CompilerEngine engine(context);
-engine.setClientParametersFuncName(funcname);
+engine.setCompilationOptions(options);
// Compile to a library
auto library = engine.compile(program, outputPath);
@@ -48,9 +48,13 @@ public:
return std::move(err);
}
if (!options.clientParametersFuncName.hasValue()) {
return StreamStringError("Need to have a funcname to compile library");
}
auto result = std::make_unique<LibraryCompilationResult>();
result->libraryPath = outputPath;
-result->funcName = funcname;
+result->funcName = *options.clientParametersFuncName;
return std::move(result);
}
using LambdaSupport::compile;

View File

@@ -20,6 +20,7 @@
#include <stdexcept>
#include <string>
using mlir::concretelang::CompilationOptions;
using mlir::concretelang::JitCompilerEngine;
using mlir::concretelang::JitLambdaSupport;
using mlir::concretelang::LambdaArgument;
@@ -43,19 +44,25 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
m.def("terminate_parallelization", &terminateParallelization);
pybind11::class_<JitCompilerEngine>(m, "JitCompilerEngine")
.def(pybind11::init())
-.def_static("build_lambda",
-[](std::string mlir_input, std::string func_name,
-std::string runtime_lib_path, std::string keysetcache_path,
-bool auto_parallelize, bool loop_parallelize,
-bool df_parallelize) {
-return buildLambda(mlir_input.c_str(), func_name.c_str(),
-noEmptyStringPtr(runtime_lib_path),
-noEmptyStringPtr(keysetcache_path),
-auto_parallelize, loop_parallelize,
-df_parallelize);
-});
+pybind11::class_<CompilationOptions>(m, "CompilationOptions")
+.def(pybind11::init(
+[](std::string funcname) { return CompilationOptions(funcname); }))
+.def("set_funcname",
+[](CompilationOptions &options, std::string funcname) {
+options.clientParametersFuncName = funcname;
+})
+.def("set_verify_diagnostics",
+[](CompilationOptions &options, bool b) {
+options.verifyDiagnostics = b;
+})
+.def("auto_parallelize", [](CompilationOptions &options,
+bool b) { options.autoParallelize = b; })
+.def("loop_parallelize", [](CompilationOptions &options,
+bool b) { options.loopParallelize = b; })
+.def("dataflow_parallelize", [](CompilationOptions &options, bool b) {
+options.dataflowParallelize = b;
+});
pybind11::class_<mlir::concretelang::JitCompilationResult>(
m, "JitCompilationResult");
pybind11::class_<mlir::concretelang::JITLambda>(m, "JITLambda");
@@ -65,9 +72,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
}))
.def("compile",
[](JITLambdaSupport_C &support, std::string mlir_program,
-std::string func_name) {
-return jit_compile(support, mlir_program.c_str(),
-func_name.c_str());
+CompilationOptions options) {
+return jit_compile(support, mlir_program.c_str(), options);
})
.def("load_client_parameters",
[](JITLambdaSupport_C &support,
@@ -102,9 +108,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
}))
.def("compile",
[](LibraryLambdaSupport_C &support, std::string mlir_program,
-std::string func_name) {
-return library_compile(support, mlir_program.c_str(),
-func_name.c_str());
+mlir::concretelang::CompilationOptions options) {
+return library_compile(support, mlir_program.c_str(), options);
})
.def("load_client_parameters",
[](LibraryLambdaSupport_C &support,

View File

@@ -22,6 +22,8 @@ from mlir._mlir_libs._concretelang._compiler import PublicResult
from mlir._mlir_libs._concretelang._compiler import PublicArguments
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
from mlir._mlir_libs._concretelang._compiler import CompilationOptions
from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport as _JITLambdaSupport
from mlir._mlir_libs._concretelang._compiler import JitCompilationResult
from mlir._mlir_libs._concretelang._compiler import JITLambda
@@ -259,7 +261,7 @@ class JITCompilerSupport:
runtime_lib_path = _lookup_runtime_lib()
self._support = _JITLambdaSupport(runtime_lib_path)
-def compile(self, mlir_program: str, func_name: str = "main") -> JitCompilationResult:
+def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> JitCompilationResult:
"""JIT Compile a function define in the mlir_program to its homomorphic equivalent.
Args:
@@ -271,7 +273,7 @@ class JITCompilerSupport:
"""
if not isinstance(mlir_program, str):
raise TypeError("mlir_program must be an `str`")
-return self._support.compile(mlir_program, func_name)
+return self._support.compile(mlir_program, options)
def load_client_parameters(self, compilation_result: JitCompilationResult) -> ClientParameters:
"""Load the client parameters from the JIT compilation result"""
@@ -299,7 +301,7 @@ class LibraryCompilerSupport:
self._library_path = outputPath
self._support = _LibraryLambdaSupport(outputPath)
-def compile(self, mlir_program: str, func_name: str = "main") -> LibraryCompilationResult:
+def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> LibraryCompilationResult:
"""Compile a function define in the mlir_program to its homomorphic equivalent and save as library.
Args:
@@ -311,9 +313,9 @@ class LibraryCompilerSupport:
"""
if not isinstance(mlir_program, str):
raise TypeError("mlir_program must be an `str`")
-if not isinstance(func_name, str):
+if not isinstance(options, CompilationOptions):
raise TypeError("mlir_program must be an `str`")
-return self._support.compile(mlir_program, func_name)
+return self._support.compile(mlir_program, options)
def reload(self, func_name: str = "main") -> LibraryCompilationResult:
"""Reload the library compilation result from the outputPath.

View File

@@ -32,10 +32,10 @@ jit_lambda_support(const char *runtimeLibPath) {
std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile(JITLambdaSupport_C support, const char *module,
-const char *funcname) {
+mlir::concretelang::CompilationOptions options) {
mlir::concretelang::JitLambdaSupport esupport;
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
-esupport.compile(module, funcname));
+esupport.compile(module, options));
return std::move(*compilationResult);
}
@@ -73,9 +73,9 @@ library_lambda_support(const char *outputPath) {
std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile(LibraryLambdaSupport_C support, const char *module,
-const char *funcname) {
+mlir::concretelang::CompilationOptions options) {
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
-support.support.compile(module, funcname));
+support.support.compile(module, options));
return std::move(*compilationResult);
}

View File

@@ -16,15 +16,14 @@ JitLambdaSupport::JitLambdaSupport(
: runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {}
llvm::Expected<std::unique_ptr<JitCompilationResult>>
-JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) {
+JitLambdaSupport::compile(llvm::SourceMgr &program,
+CompilationOptions options) {
// Setup the compiler engine
auto context = std::make_shared<CompilationContext>();
concretelang::CompilerEngine engine(context);
-// We need client parameters to be generated
-engine.setGenerateClientParameters(true);
-engine.setClientParametersFuncName(funcname);
+engine.setCompilationOptions(options);
// Compile to LLVM Dialect
auto compilationResult =
@@ -34,10 +33,15 @@ JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) {
return std::move(err);
}
if (!options.clientParametersFuncName.hasValue()) {
return StreamStringError("Need to have a funcname to JIT compile");
}
// Compile from LLVM Dialect to JITLambda
auto mlirModule = compilationResult.get().mlirModuleRef->get();
auto lambda = concretelang::JITLambda::create(
-funcname, mlirModule, llvmOptPipeline, runtimeLibPath);
+*options.clientParametersFuncName, mlirModule, llvmOptPipeline,
+runtimeLibPath);
if (auto err = lambda.takeError()) {
return std::move(err);
}

View File

@@ -1,7 +1,6 @@
#include "EndToEndFixture.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "llvm/Support/YAMLParser.h"
#include "llvm/Support/YAMLTraits.h"

View File

@@ -5,7 +5,7 @@
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/JitLambdaSupport.h"
#include "llvm/Support/Path.h"
#include "globals.h"