feat(python): Expose Jit and Library compiler support

This commit is contained in:
Quentin Bourgerie
2022-03-15 14:26:15 +01:00
parent f8968eb489
commit 8867d313ee
9 changed files with 711 additions and 160 deletions

View File

@@ -9,6 +9,8 @@
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/JitLambdaSupport.h"
#include "concretelang/Support/LibraryLambdaSupport.h"
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
@@ -35,11 +37,79 @@ struct executionArguments {
};
typedef struct executionArguments executionArguments;
// JIT Support bindings ///////////////////////////////////////////////////////
struct JITLambdaSupport_C {
mlir::concretelang::JitLambdaSupport support;
};
typedef struct JITLambdaSupport_C JITLambdaSupport_C;
MLIR_CAPI_EXPORTED JITLambdaSupport_C
jit_lambda_support(const char *runtimeLibPath);
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile(JITLambdaSupport_C support, const char *module,
const char *funcname);
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
jit_load_client_parameters(JITLambdaSupport_C support,
mlir::concretelang::JitCompilationResult &);
MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda *
jit_load_server_lambda(JITLambdaSupport_C support,
mlir::concretelang::JitCompilationResult &);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
jit_server_call(JITLambdaSupport_C support,
mlir::concretelang::JITLambda *lambda,
concretelang::clientlib::PublicArguments &args);
// Library Support bindings ///////////////////////////////////////////////////
struct LibraryLambdaSupport_C {
mlir::concretelang::LibraryLambdaSupport support;
};
typedef struct LibraryLambdaSupport_C LibraryLambdaSupport_C;
MLIR_CAPI_EXPORTED LibraryLambdaSupport_C
library_lambda_support(const char *outputPath);
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile(LibraryLambdaSupport_C support, const char *module,
const char *funcname);
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
library_load_client_parameters(LibraryLambdaSupport_C support,
mlir::concretelang::LibraryCompilationResult &);
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
library_load_server_lambda(LibraryLambdaSupport_C support,
mlir::concretelang::LibraryCompilationResult &);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
library_server_call(LibraryLambdaSupport_C support,
concretelang::serverlib::ServerLambda lambda,
concretelang::clientlib::PublicArguments &args);
// Client Support bindings ///////////////////////////////////////////////////
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
key_set(concretelang::clientlib::ClientParameters clientParameters,
llvm::Optional<concretelang::clientlib::KeySetCache> cache);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
concretelang::clientlib::KeySet &keySet,
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args);
MLIR_CAPI_EXPORTED lambdaArgument
decrypt_result(concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::PublicResult &publicResult);
// Build lambda from a textual representation of an MLIR module.
// The lambda will have `funcName` as entrypoint, use runtimeLibPath (if not
// null) as a shared library during compilation, a path to activate the use of
// a cache for encryption keys for test purposes (unsecure), and a set of
// flags for parallelization.
MLIR_CAPI_EXPORTED mlir::concretelang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName,
const char *runtimeLibPath, const char *keySetCachePath,

View File

@@ -35,8 +35,7 @@ public:
JitLambdaSupport(
llvm::Optional<llvm::StringRef> runtimeLibPath = llvm::None,
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr))
: runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {}
mlir::makeOptimizingTransformer(3, 0, nullptr));
llvm::Expected<std::unique_ptr<JitCompilationResult>>
compile(llvm::SourceMgr &program, std::string funcname = "main") override;

View File

@@ -33,8 +33,7 @@ class LibraryLambdaSupport
: public LambdaSupport<serverlib::ServerLambda, LibraryCompilationResult> {
public:
LibraryLambdaSupport(std::string outputPath = "/tmp/toto")
: outputPath(outputPath) {}
LibraryLambdaSupport(std::string outputPath) : outputPath(outputPath) {}
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
compile(llvm::SourceMgr &program, std::string funcname = "main") override {

View File

@@ -8,6 +8,7 @@
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/JitLambdaSupport.h"
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/ExecutionEngine/OptUtils.h>
@@ -20,6 +21,7 @@
#include <string>
using mlir::concretelang::JitCompilerEngine;
using mlir::concretelang::JitLambdaSupport;
using mlir::concretelang::LambdaArgument;
const char *noEmptyStringPtr(std::string &s) {
@@ -54,6 +56,110 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
auto_parallelize, loop_parallelize,
df_parallelize);
});
pybind11::class_<mlir::concretelang::JitCompilationResult>(
m, "JitCompilationResult");
pybind11::class_<mlir::concretelang::JITLambda>(m, "JITLambda");
pybind11::class_<JITLambdaSupport_C>(m, "JITLambdaSupport")
.def(pybind11::init([](std::string runtimeLibPath) {
return jit_lambda_support(runtimeLibPath.c_str());
}))
.def("compile",
[](JITLambdaSupport_C &support, std::string mlir_program,
std::string func_name) {
return jit_compile(support, mlir_program.c_str(),
func_name.c_str());
})
.def("load_client_parameters",
[](JITLambdaSupport_C &support,
mlir::concretelang::JitCompilationResult &result) {
return jit_load_client_parameters(support, result);
})
.def(
"load_server_lambda",
[](JITLambdaSupport_C &support,
mlir::concretelang::JitCompilationResult &result) {
return jit_load_server_lambda(support, result);
},
pybind11::return_value_policy::reference)
.def("server_call",
[](JITLambdaSupport_C &support, concretelang::JITLambda *lambda,
clientlib::PublicArguments &publicArguments) {
return jit_server_call(support, lambda, publicArguments);
});
pybind11::class_<mlir::concretelang::LibraryCompilationResult>(
m, "LibraryCompilationResult")
.def(pybind11::init([](std::string libraryPath, std::string funcname) {
return mlir::concretelang::LibraryCompilationResult{
libraryPath,
funcname,
};
}));
pybind11::class_<concretelang::serverlib::ServerLambda>(m, "LibraryLambda");
pybind11::class_<LibraryLambdaSupport_C>(m, "LibraryLambdaSupport")
.def(pybind11::init([](std::string outputPath) {
return library_lambda_support(outputPath.c_str());
}))
.def("compile",
[](LibraryLambdaSupport_C &support, std::string mlir_program,
std::string func_name) {
return library_compile(support, mlir_program.c_str(),
func_name.c_str());
})
.def("load_client_parameters",
[](LibraryLambdaSupport_C &support,
mlir::concretelang::LibraryCompilationResult &result) {
return library_load_client_parameters(support, result);
})
.def(
"load_server_lambda",
[](LibraryLambdaSupport_C &support,
mlir::concretelang::LibraryCompilationResult &result) {
return library_load_server_lambda(support, result);
},
pybind11::return_value_policy::reference)
.def("server_call",
[](LibraryLambdaSupport_C &support, serverlib::ServerLambda lambda,
clientlib::PublicArguments &publicArguments) {
return library_server_call(support, lambda, publicArguments);
});
class ClientSupport {};
pybind11::class_<ClientSupport>(m, "ClientSupport")
.def(pybind11::init())
.def_static(
"key_set",
[](clientlib::ClientParameters clientParameters,
clientlib::KeySetCache *cache) {
auto optCache =
cache == nullptr
? llvm::None
: llvm::Optional<clientlib::KeySetCache>(*cache);
return key_set(clientParameters, optCache);
},
pybind11::arg().none(false), pybind11::arg().none(true))
.def_static("encrypt_arguments",
[](clientlib::ClientParameters clientParameters,
clientlib::KeySet &keySet,
std::vector<lambdaArgument> args) {
std::vector<mlir::concretelang::LambdaArgument *> argsRef;
for (auto i = 0u; i < args.size(); i++) {
argsRef.push_back(args[i].ptr.get());
}
return encrypt_arguments(clientParameters, keySet, argsRef);
})
.def_static("decrypt_result", [](clientlib::KeySet &keySet,
clientlib::PublicResult &publicResult) {
return decrypt_result(keySet, publicResult);
});
pybind11::class_<KeySetCache>(m, "KeySetCache")
.def(pybind11::init<std::string &>());
pybind11::class_<mlir::concretelang::ClientParameters>(m, "ClientParameters");
pybind11::class_<clientlib::KeySet>(m, "KeySet");
pybind11::class_<clientlib::PublicArguments>(m, "PublicArguments");
pybind11::class_<clientlib::PublicResult>(m, "PublicResult");
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
.def_static("from_tensor",

View File

@@ -14,6 +14,23 @@ from mlir._mlir_libs._concretelang._compiler import (
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
from mlir._mlir_libs._concretelang._compiler import library as _library
from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport
from mlir._mlir_libs._concretelang._compiler import LibraryLambdaSupport
from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport
from mlir._mlir_libs._concretelang._compiler import ClientParameters
from mlir._mlir_libs._concretelang._compiler import KeySet
from mlir._mlir_libs._concretelang._compiler import KeySetCache
from mlir._mlir_libs._concretelang._compiler import PublicResult
from mlir._mlir_libs._concretelang._compiler import PublicArguments
from mlir._mlir_libs._concretelang._compiler import JitCompilationResult
from mlir._mlir_libs._concretelang._compiler import JITLambda
from mlir._mlir_libs._concretelang._compiler import LibraryCompilationResult
from mlir._mlir_libs._concretelang._compiler import LibraryLambda
import numpy as np
@@ -46,7 +63,8 @@ def _lookup_runtime_lib() -> str:
for filename in os.listdir(libs_path)
if filename.startswith("libConcretelangRuntime")
]
assert len(runtime_library_paths) == 1, "should be one and only one runtime library"
assert len(
runtime_library_paths) == 1, "should be one and only one runtime library"
return os.path.join(libs_path, runtime_library_paths[0])
@@ -67,10 +85,10 @@ def round_trip(mlir_str: str) -> str:
return _round_trip(mlir_str)
_MLIR_MODULES_TYPE = "mlir_modules must be an `iterable` of `str` or a `str"
_MLIR_MODULES_TYPE = 'mlir_modules must be an `iterable` of `str` or a `str'
def library(library_path: str, mlir_modules: Union["Iterable[str]", str]) -> str:
def library(library_path: str, mlir_modules: Union['Iterable[str]', str]) -> str:
"""Compile the MLIR inputs to a library.
Args:
@@ -101,7 +119,7 @@ def library(library_path: str, mlir_modules: Union["Iterable[str]", str]) -> str
return _library(library_path, mlir_modules)
def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument":
def create_execution_argument(value: Union[int, np.ndarray]) -> _LambdaArgument:
"""Create an execution argument holding either an int or tensor value.
Args:
@@ -115,8 +133,7 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
"""
if not isinstance(value, ACCEPTED_TYPES):
raise TypeError(
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
)
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}")
if isinstance(value, ACCEPTED_INTS):
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
raise TypeError(
@@ -134,7 +151,7 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
class CompilerEngine:
def __init__(self, mlir_str: str = None):
self._engine = _JitCompilerEngine()
self._engine = JITCompilerSupport()
self._lambda = None
if mlir_str is not None:
self.compile_fhe(mlir_str)
@@ -182,17 +199,17 @@ class CompilerEngine:
)
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
if not isinstance(unsecure_key_set_cache_path, str):
raise TypeError("unsecure_key_set_cache_path must be a str")
self._lambda = self._engine.build_lambda(
mlir_str,
func_name,
runtime_lib_path,
unsecure_key_set_cache_path,
auto_parallelize,
loop_parallelize,
df_parallelize,
)
raise TypeError(
"unsecure_key_set_cache_path must be a str"
)
self._compilation_result = self._engine.compile(mlir_str)
self._client_parameters = self._engine.load_client_parameters(
self._compilation_result)
keyset_cache = None
if not unsecure_key_set_cache_path is None:
keyset_cache = KeySetCache(unsecure_key_set_cache_path)
self._key_set = ClientSupport.key_set(
self._client_parameters, keyset_cache)
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
"""Run the compiled code.
@@ -208,10 +225,59 @@ class CompilerEngine:
Returns:
int or numpy.array: result of execution.
"""
if self._lambda is None:
if self._compilation_result is None:
raise RuntimeError("need to compile an MLIR code first")
# Client
public_arguments = ClientSupport.encrypt_arguments(self._client_parameters,
self._key_set, args)
# Server
server_lambda = self._engine.load_server_lambda(
self._compilation_result)
public_result = self._engine.server_call(
server_lambda, public_arguments)
# Client
return ClientSupport.decrypt_result(self._key_set, public_result)
class ClientSupport:
def key_set(client_parameters: ClientParameters, cache: KeySetCache = None) -> KeySet:
"""Generates a key set according to the given client parameters.
If the cache is set the key set is loaded from it if exists, else the new generated key set is saved in the cache
Args:
client_parameters: A client parameters specification
cache: An optional cache of key set.
Returns:
KeySet: the key set
"""
return _ClientSupport.key_set(client_parameters, cache)
def encrypt_arguments(client_parameters: ClientParameters, key_set: KeySet, args: List[Union[int, np.ndarray]]) -> PublicArguments:
    """Export clear arguments to public arguments.

    For each argument this method encrypts the argument if it is declared as
    encrypted and packs it into the public arguments object.

    Args:
        client_parameters: A client parameters specification.
        key_set: A key set used to encrypt encrypted arguments.
        args: The clear arguments (ints or numpy arrays) to export.

    Returns:
        PublicArguments: the public arguments.
    """
    execution_arguments = [create_execution_argument(arg) for arg in args]
    # Bug fix: removed a stray leftover line
    # (`lambda_arg = self._lambda.invoke(execution_arguments)`) that
    # referenced `self`, which is not in scope here, and raised NameError
    # on every call.
    return _ClientSupport.encrypt_arguments(client_parameters, key_set, execution_arguments)
def decrypt_result(key_set: KeySet, public_result: PublicResult) -> Union[int, np.ndarray]:
"""Decrypt a public result thanks the given key set.
Args:
key_set: The key set used to decrypt the result.
public_result: The public result to decrypt.
Returns:
int or numpy.array: The result of decryption.
"""
lambda_arg = _ClientSupport.decrypt_result(key_set, public_result)
if lambda_arg.is_scalar():
return lambda_arg.get_scalar()
elif lambda_arg.is_tensor():
@@ -220,3 +286,103 @@ class CompilerEngine:
return tensor
else:
raise RuntimeError("unknown return type")
class JITCompilerSupport:
def __init__(self, runtime_lib_path=None):
if runtime_lib_path is None:
runtime_lib_path = _lookup_runtime_lib()
self._support = JITLambdaSupport(runtime_lib_path)
def compile(self, mlir_program: str, func_name: str = "main") -> JitCompilationResult:
"""JIT Compile a function define in the mlir_program to its homomorphic equivalent.
Args:
mlir_program: A textual representation of the mlir program to compile.
func_name: The name of the function to compile.
Returns:
JITCompilationResult: the result of the JIT compilation.
"""
if not isinstance(mlir_program, str):
raise TypeError("mlir_program must be an `str`")
return self._support.compile(mlir_program, func_name)
def load_client_parameters(self, compilation_result: JitCompilationResult) -> ClientParameters:
"""Load the client parameters from the JIT compilation result"""
return self._support.load_client_parameters(compilation_result)
def load_server_lambda(self, compilation_result: JitCompilationResult) -> JITLambda:
"""Load the server lambda from the JIT compilation result"""
return self._support.load_server_lambda(compilation_result)
def server_call(self, server_lambda: JITLambda, public_arguments: PublicArguments):
"""Call the server lambda with public_arguments
Args:
server_lambda: A server lambda to call
public_arguments: The arguments of the call
Returns:
PublicResult: the result of the call of the server lambda
"""
return self._support.server_call(server_lambda, public_arguments)
class LibraryCompilerSupport:
def __init__(self, outputPath="./out"):
self._library_path = outputPath
self._support = LibraryLambdaSupport(outputPath)
def compile(self, mlir_program: str, func_name: str = "main") -> LibraryCompilationResult:
"""Compile a function define in the mlir_program to its homomorphic equivalent and save as library.
Args:
mlir_program: A textual representation of the mlir program to compile.
func_name: The name of the function to compile.
Returns:
LibraryCompilationResult: the result of the compilation.
"""
if not isinstance(mlir_program, str):
raise TypeError("mlir_program must be an `str`")
if not isinstance(func_name, str):
raise TypeError("mlir_program must be an `str`")
return self._support.compile(mlir_program, func_name)
def reload(self, func_name: str = "main") -> LibraryCompilationResult:
"""Reload the library compilation result from the output path given at construction.

Args:
func_name: The name of the compiled function.

Returns:
LibraryCompilationResult: the result of a compilation.

Raises:
TypeError: if func_name is not a `str`.
"""
if not isinstance(func_name, str):
raise TypeError("func_name must be an `str`")
return LibraryCompilationResult(self._library_path, func_name)
def load_client_parameters(self, compilation_result: LibraryCompilationResult) -> ClientParameters:
"""Load the client parameters from the library compilation result.

Raises:
TypeError: if compilation_result is not a `LibraryCompilationResult`.
"""
if not isinstance(compilation_result, LibraryCompilationResult):
raise TypeError(
"compilation_result must be an `LibraryCompilationResult`")
return self._support.load_client_parameters(compilation_result)
def load_server_lambda(self, compilation_result: LibraryCompilationResult) -> LibraryLambda:
"""Load the server lambda from the library compilation result."""
return self._support.load_server_lambda(compilation_result)
def server_call(self, server_lambda: LibraryLambda, public_arguments: PublicArguments) -> PublicResult:
"""Call the server lambda with public_arguments
Args:
server_lambda: A server lambda to call
public_arguments: The arguments of the call
Returns:
PublicResult: the result of the call of the server lambda
"""
return self._support.server_call(server_lambda, public_arguments)

View File

@@ -11,9 +11,134 @@
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/JitLambdaSupport.h"
using mlir::concretelang::JitCompilerEngine;
// Unwrap an llvm::Expected into VARNAME; on failure, convert the llvm::Error
// into a std::runtime_error so pybind11 can surface it as a Python exception.
#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
auto VARNAME = EXPECTED; \
if (auto err = VARNAME.takeError()) { \
throw std::runtime_error(llvm::toString(std::move(err))); \
}
// JIT Support bindings ///////////////////////////////////////////////////////
// Build a JIT lambda support wrapper. An empty runtimeLibPath string is
// mapped to llvm::None so no extra runtime library is linked by the JIT.
MLIR_CAPI_EXPORTED JITLambdaSupport_C
jit_lambda_support(const char *runtimeLibPath) {
llvm::StringRef str(runtimeLibPath);
auto opt = str.empty() ? llvm::None : llvm::Optional<llvm::StringRef>(str);
return JITLambdaSupport_C{mlir::concretelang::JitLambdaSupport(opt)};
}
// JIT-compile the textual MLIR `module` with `funcname` as entrypoint.
// Throws std::runtime_error on compilation failure.
std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile(JITLambdaSupport_C support, const char *module,
            const char *funcname) {
  // Bug fix: the previous code constructed a fresh default JitLambdaSupport
  // and compiled with it, silently ignoring the `support` argument (and thus
  // the runtimeLibPath configured in jit_lambda_support). Use the wrapped
  // support, consistent with every other jit_* binding in this file.
  GET_OR_THROW_LLVM_EXPECTED(compilationResult,
                             support.support.compile(module, funcname));
  return std::move(*compilationResult);
}
// Extract the client parameters produced by a JIT compilation.
// Throws std::runtime_error (via GET_OR_THROW_LLVM_EXPECTED) on failure.
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
jit_load_client_parameters(JITLambdaSupport_C support,
mlir::concretelang::JitCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(clientParameters,
support.support.loadClientParameters(result));
return *clientParameters;
}
// Load the server-side JIT lambda from a compilation result.
// NOTE(review): returns a raw pointer; ownership appears to remain with the
// compilation result / support — confirm before storing it long-term.
MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda *
jit_load_server_lambda(JITLambdaSupport_C support,
mlir::concretelang::JitCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
support.support.loadServerLambda(result));
return *serverLambda;
}
// Invoke the JIT lambda on encrypted public arguments and return the
// public (still encrypted) result. Throws std::runtime_error on failure.
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
jit_server_call(JITLambdaSupport_C support,
mlir::concretelang::JITLambda *lambda,
concretelang::clientlib::PublicArguments &args) {
GET_OR_THROW_LLVM_EXPECTED(publicResult,
support.support.serverCall(lambda, args));
return std::move(*publicResult);
}
// Library Support bindings ///////////////////////////////////////////////////
// Build a library lambda support wrapper writing compilation artifacts
// under `outputPath`.
MLIR_CAPI_EXPORTED LibraryLambdaSupport_C
library_lambda_support(const char *outputPath) {
return LibraryLambdaSupport_C{
mlir::concretelang::LibraryLambdaSupport(outputPath)};
}
// Compile the textual MLIR `module` (entrypoint `funcname`) into a library
// via the wrapped support. Throws std::runtime_error on failure.
std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile(LibraryLambdaSupport_C support, const char *module,
const char *funcname) {
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
support.support.compile(module, funcname));
return std::move(*compilationResult);
}
// Extract the client parameters produced by a library compilation.
// Throws std::runtime_error on failure.
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
library_load_client_parameters(
LibraryLambdaSupport_C support,
mlir::concretelang::LibraryCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(clientParameters,
support.support.loadClientParameters(result));
return *clientParameters;
}
// Load the server lambda from a library compilation result (returned by
// value, unlike the JIT variant which returns a raw pointer).
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
library_load_server_lambda(
LibraryLambdaSupport_C support,
mlir::concretelang::LibraryCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
support.support.loadServerLambda(result));
return *serverLambda;
}
// Invoke a library-compiled server lambda on encrypted public arguments and
// return the public result. Throws std::runtime_error on failure.
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
library_server_call(LibraryLambdaSupport_C support,
concretelang::serverlib::ServerLambda lambda,
concretelang::clientlib::PublicArguments &args) {
GET_OR_THROW_LLVM_EXPECTED(publicResult,
support.support.serverCall(lambda, args));
return std::move(*publicResult);
}
// Client Support bindings ///////////////////////////////////////////////////
// Generate (or load through the optional cache) a key set matching
// clientParameters. LambdaSupport<int, int> is an arbitrary instantiation
// used only to reach the static keySet helper.
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
key_set(concretelang::clientlib::ClientParameters clientParameters,
llvm::Optional<concretelang::clientlib::KeySetCache> cache) {
GET_OR_THROW_LLVM_EXPECTED(
ks, (mlir::concretelang::LambdaSupport<int, int>::keySet(clientParameters,
cache)));
return std::move(*ks);
}
// Export (encrypting where the parameters require it) the given lambda
// arguments into public arguments. LambdaSupport<int, int> is an arbitrary
// instantiation used only to reach the static exportArguments helper.
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
concretelang::clientlib::KeySet &keySet,
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args) {
GET_OR_THROW_LLVM_EXPECTED(
publicArguments,
(mlir::concretelang::LambdaSupport<int, int>::exportArguments(
clientParameters, keySet, args)));
return std::move(*publicArguments);
}
// Decrypt a public result into a typed LambdaArgument, wrapped in the
// Python-facing lambdaArgument holder. Throws std::runtime_error on failure.
MLIR_CAPI_EXPORTED lambdaArgument
decrypt_result(concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::PublicResult &publicResult) {
GET_OR_THROW_LLVM_EXPECTED(
result, mlir::concretelang::typedResult<
std::unique_ptr<mlir::concretelang::LambdaArgument>>(
keySet, publicResult));
lambdaArgument result_{std::move(*result)};
return std::move(result_);
}
mlir::concretelang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName,
const char *runtimeLibPath, const char *keySetCachePath,
@@ -216,7 +341,7 @@ std::string library(std::string libraryPath,
using namespace mlir::concretelang;
JitCompilerEngine ce{CompilationContext::createShared()};
auto lib = ce.compile<std::string>(mlir_modules, libraryPath);
auto lib = ce.compile(mlir_modules, libraryPath);
if (!lib) {
throw std::runtime_error("Can't link: " + llvm::toString(lib.takeError()));
}

View File

@@ -4,12 +4,20 @@
// for license information.
#include <concretelang/Support/JitLambdaSupport.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
namespace mlir {
namespace concretelang {
JitLambdaSupport::JitLambdaSupport(
llvm::Optional<llvm::StringRef> runtimeLibPath,
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline)
: runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {}
llvm::Expected<std::unique_ptr<JitCompilationResult>>
JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) {
// Setup the compiler engine
auto context = std::make_shared<CompilationContext>();
concretelang::CompilerEngine engine(context);

View File

@@ -4,80 +4,116 @@ import tempfile
import pytest
import numpy as np
from concrete.compiler import CompilerEngine, library
from lib.Bindings.Python.concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
from lib.Bindings.Python.concrete.compiler import ClientSupport
from lib.Bindings.Python.concrete.compiler import KeySetCache
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
[
pytest.param(
"""
keySetCacheTest = KeySetCache(KEY_SET_CACHE_PATH)
def compile_and_run(engine, mlir_input, args, expected_result):
compilation_result = engine.compile(mlir_input)
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
public_arguments = ClientSupport.encrypt_arguments(
client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
public_result = engine.server_call(server_lambda, public_arguments)
# Client
result = ClientSupport.decrypt_result(key_set, public_result)
# Check result
assert type(expected_result) == type(result)
if isinstance(expected_result, int):
assert result == expected_result
else:
assert np.all(result == expected_result)
end_to_end_fixture = [
pytest.param(
"""
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
""",
(5, 7),
12,
id="add_eint_int",
),
pytest.param(
"""
(5, 7),
12,
id="add_eint_int",
),
pytest.param(
"""
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
""",
(np.array(4, dtype=np.uint8), np.array(5, dtype=np.uint8)),
9,
id="add_eint_int_with_ndarray_as_scalar",
),
pytest.param(
"""
(np.array(4, dtype=np.uint8), np.array(5, dtype=np.uint8)),
9,
id="add_eint_int_with_ndarray_as_scalar",
),
pytest.param(
"""
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
""",
(np.uint8(3), np.uint8(5)),
8,
id="add_eint_int_with_np_uint8_as_scalar",
),
pytest.param(
"""
(np.uint8(3), np.uint8(5)),
8,
id="add_eint_int_with_np_uint8_as_scalar",
),
pytest.param(
"""
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
""",
(np.uint16(3), np.uint16(5)),
8,
id="add_eint_int_with_np_uint16_as_scalar",
),
pytest.param(
"""
(np.uint16(3), np.uint16(5)),
8,
id="add_eint_int_with_np_uint16_as_scalar",
),
pytest.param(
"""
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
""",
(np.uint32(3), np.uint32(5)),
8,
id="add_eint_int_with_np_uint32_as_scalar",
),
pytest.param(
"""
(np.uint32(3), np.uint32(5)),
8,
id="add_eint_int_with_np_uint32_as_scalar",
),
pytest.param(
"""
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
""",
(np.uint64(3), np.uint64(5)),
8,
id="add_eint_int_with_np_uint64_as_scalar",
),
pytest.param(
"""
(np.uint64(3), np.uint64(5)),
8,
id="add_eint_int_with_np_uint64_as_scalar",
),
pytest.param(
"""
func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
""",
(73,),
73,
id="apply_lookup_table",
),
pytest.param(
"""
func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
{
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
@@ -85,15 +121,15 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
return %ret : !FHE.eint<7>
}
""",
(
np.array([1, 2, 3, 4], dtype=np.uint8),
np.array([4, 3, 2, 1], dtype=np.uint8),
),
20,
id="dot_eint_int_uint8",
(
np.array([1, 2, 3, 4], dtype=np.uint8),
np.array([4, 3, 2, 1], dtype=np.uint8),
),
pytest.param(
"""
20,
id="dot_eint_int_uint8",
),
pytest.param(
"""
func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
{
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
@@ -101,15 +137,15 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
return %ret : !FHE.eint<7>
}
""",
(
np.array([1, 2, 3, 4], dtype=np.uint16),
np.array([4, 3, 2, 1], dtype=np.uint16),
),
20,
id="dot_eint_int_uint16",
(
np.array([1, 2, 3, 4], dtype=np.uint16),
np.array([4, 3, 2, 1], dtype=np.uint16),
),
pytest.param(
"""
20,
id="dot_eint_int_uint16",
),
pytest.param(
"""
func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
{
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
@@ -117,15 +153,15 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
return %ret : !FHE.eint<7>
}
""",
(
np.array([1, 2, 3, 4], dtype=np.uint32),
np.array([4, 3, 2, 1], dtype=np.uint32),
),
20,
id="dot_eint_int_uint32",
(
np.array([1, 2, 3, 4], dtype=np.uint32),
np.array([4, 3, 2, 1], dtype=np.uint32),
),
pytest.param(
"""
20,
id="dot_eint_int_uint32",
),
pytest.param(
"""
func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
{
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
@@ -133,91 +169,128 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
return %ret : !FHE.eint<7>
}
""",
(
np.array([1, 2, 3, 4], dtype=np.uint64),
np.array([4, 3, 2, 1], dtype=np.uint64),
),
20,
id="dot_eint_int_uint64",
(
np.array([1, 2, 3, 4], dtype=np.uint64),
np.array([4, 3, 2, 1], dtype=np.uint64),
),
pytest.param(
"""
20,
id="dot_eint_int_uint64",
),
pytest.param(
"""
func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}
""",
(
np.array([31, 6, 12, 9], dtype=np.uint8),
np.array([32, 9, 2, 3], dtype=np.uint8),
),
np.array([63, 15, 14, 12]),
id="add_eint_int_1D",
(
np.array([31, 6, 12, 9], dtype=np.uint8),
np.array([32, 9, 2, 3], dtype=np.uint8),
),
pytest.param(
"""
np.array([63, 15, 14, 12]),
id="add_eint_int_1D",
),
pytest.param(
"""
func @main(%a0: tensor<4x4x!FHE.eint<6>>, %a1: tensor<4x4xi7>) -> tensor<4x4x!FHE.eint<6>> {
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x4x!FHE.eint<6>>, tensor<4x4xi7>) -> tensor<4x4x!FHE.eint<6>>
return %res : tensor<4x4x!FHE.eint<6>>
}
""",
(
np.array(
[[31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9]],
dtype=np.uint8,
),
np.array(
[[32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3]],
dtype=np.uint8,
),
),
(
np.array(
[
[63, 15, 14, 12],
[63, 15, 14, 12],
[63, 15, 14, 12],
[63, 15, 14, 12],
],
[[31, 6, 12, 9], [31, 6, 12, 9], [
31, 6, 12, 9], [31, 6, 12, 9]],
dtype=np.uint8,
),
np.array(
[[32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3]],
dtype=np.uint8,
),
id="add_eint_int_2D",
),
pytest.param(
"""
np.array(
[
[63, 15, 14, 12],
[63, 15, 14, 12],
[63, 15, 14, 12],
[63, 15, 14, 12],
],
dtype=np.uint8,
),
id="add_eint_int_2D",
),
pytest.param(
"""
func @main(%a0: tensor<2x2x2x!FHE.eint<6>>, %a1: tensor<2x2x2xi7>) -> tensor<2x2x2x!FHE.eint<6>> {
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x2x!FHE.eint<6>>, tensor<2x2x2xi7>) -> tensor<2x2x2x!FHE.eint<6>>
return %res : tensor<2x2x2x!FHE.eint<6>>
}
""",
(
np.array(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
dtype=np.uint8,
),
np.array(
[[[9, 10], [11, 12]], [[13, 14], [15, 16]]],
dtype=np.uint8,
),
),
(
np.array(
[[[10, 12], [14, 16]], [[18, 20], [22, 24]]],
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
dtype=np.uint8,
),
np.array(
[[[9, 10], [11, 12]], [[13, 14], [15, 16]]],
dtype=np.uint8,
),
id="add_eint_int_3D",
),
],
)
def test_compile_and_run(mlir_input, args, expected_result):
    """Compile an MLIR program and check the clear result of running it."""
    engine = CompilerEngine()
    engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
    result = engine.run(*args)
    if isinstance(expected_result, int):
        assert result == expected_result
    else:
        # expected_result is a numpy array: compare element-wise.
        assert np.all(result == expected_result)
np.array(
[[[10, 12], [14, 16]], [[18, 20], [22, 24]]],
dtype=np.uint8,
),
id="add_eint_int_3D",
),
]
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_jit_compile_and_run(mlir_input, args, expected_result):
    """End-to-end check of the JIT compilation path on the shared fixtures."""
    compile_and_run(JITCompilerSupport(), mlir_input, args, expected_result)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_and_run(mlir_input, args, expected_result):
    """End-to-end check of the ahead-of-time library compilation path."""
    support = LibraryCompilerSupport("py_test_lib_compile_and_run")
    compile_and_run(support, mlir_input, args, expected_result)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
    """Compile to a library, reload the compilation result, then run it.

    Unlike test_lib_compile_and_run, the in-memory compilation result is
    discarded and re-created via reload() to exercise the on-disk path.
    """
    support = LibraryCompilerSupport("test_lib_compile_reload_and_run")
    # Compile without keeping the result, then reload it from the library.
    support.compile(mlir_input)
    compilation_result = support.reload()
    # Client side: derive parameters, build the key set, encrypt inputs.
    client_parameters = support.load_client_parameters(compilation_result)
    key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
    public_arguments = ClientSupport.encrypt_arguments(
        client_parameters, key_set, args)
    # Server side: evaluate the compiled lambda on encrypted arguments.
    server_lambda = support.load_server_lambda(compilation_result)
    public_result = support.server_call(server_lambda, public_arguments)
    # Client side: decrypt and compare against the expected clear result.
    result = ClientSupport.decrypt_result(key_set, public_result)
    assert type(result) == type(expected_result)
    if isinstance(expected_result, int):
        assert result == expected_result
    else:
        assert np.all(result == expected_result)
@ pytest.mark.parametrize(
"mlir_input, args",
[
pytest.param(
@@ -234,8 +307,9 @@ def test_compile_and_run(mlir_input, args, expected_result):
)
def test_compile_and_run_invalid_arg_number(mlir_input, args):
    """Check that running with too many arguments raises at call time.

    Fix: the diff view interleaved the removed and added versions of this
    test, leaving a duplicated compile_fhe call and two conflicting
    pytest.raises contexts; keep only the current (RuntimeError) version.
    """
    engine = CompilerEngine()
    engine.compile_fhe(
        mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
    # Arity mismatches are reported by the runtime when running, not at
    # compilation time.
    with pytest.raises(RuntimeError, match=r"function has arity 2 but is applied to too many arguments"):
        engine.run(*args)
@@ -259,7 +333,8 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
)
def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
    """Run a table-lookup program and check the result within tolerance.

    The TLU result is approximate: accept a relative error below 10% of
    the table size.

    Fix: the diff view left both the old and new compile_fhe lines in
    place, so the program was compiled twice; keep a single call.
    """
    engine = CompilerEngine()
    engine.compile_fhe(
        mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
    assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
@@ -281,8 +356,9 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
)
def test_compile_invalid(mlir_input):
    """Check that compiling an invalid MLIR program raises.

    Fix: the diff view interleaved two pytest.raises contexts (the old
    "Compilation failed:" match and the new one); keep only the current
    version, which matches the client-parameters generation error.
    """
    engine = CompilerEngine()
    with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"):
        engine.compile_fhe(
            mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
MODULE_1 = """

View File

@@ -14,7 +14,7 @@
\
auto desc = GetParam(); \
\
LambdaSupport support; \
auto support = LambdaSupport; \
\
/* 1 - Compile the program */ \
auto compilationResult = support.compile(desc.program); \
@@ -84,8 +84,10 @@
/// Instantiate the test suite for Jit
INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(
JitTest, mlir::concretelang::JitLambdaSupport)
JitTest, mlir::concretelang::JitLambdaSupport())
/// Instantiate the test suite for Jit
INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(
LibraryTest, mlir::concretelang::LibraryLambdaSupport)
LibraryTest,
mlir::concretelang::LibraryLambdaSupport("/tmp/end_to_end_test_" +
desc.description))