mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(python): Expose Jit and Library compiler support
This commit is contained in:
@@ -9,6 +9,8 @@
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/JitLambdaSupport.h"
|
||||
#include "concretelang/Support/LibraryLambdaSupport.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
|
||||
@@ -35,11 +37,79 @@ struct executionArguments {
|
||||
};
|
||||
typedef struct executionArguments executionArguments;
|
||||
|
||||
// JIT Support bindings ///////////////////////////////////////////////////////
|
||||
|
||||
struct JITLambdaSupport_C {
|
||||
mlir::concretelang::JitLambdaSupport support;
|
||||
};
|
||||
typedef struct JITLambdaSupport_C JITLambdaSupport_C;
|
||||
|
||||
MLIR_CAPI_EXPORTED JITLambdaSupport_C
|
||||
jit_lambda_support(const char *runtimeLibPath);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile(JITLambdaSupport_C support, const char *module,
|
||||
const char *funcname);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
jit_load_client_parameters(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda *
|
||||
jit_load_server_lambda(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JITLambda *lambda,
|
||||
concretelang::clientlib::PublicArguments &args);
|
||||
|
||||
// Library Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
struct LibraryLambdaSupport_C {
|
||||
mlir::concretelang::LibraryLambdaSupport support;
|
||||
};
|
||||
typedef struct LibraryLambdaSupport_C LibraryLambdaSupport_C;
|
||||
|
||||
MLIR_CAPI_EXPORTED LibraryLambdaSupport_C
|
||||
library_lambda_support(const char *outputPath);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibraryLambdaSupport_C support, const char *module,
|
||||
const char *funcname);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
library_load_client_parameters(LibraryLambdaSupport_C support,
|
||||
mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
|
||||
library_load_server_lambda(LibraryLambdaSupport_C support,
|
||||
mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
library_server_call(LibraryLambdaSupport_C support,
|
||||
concretelang::serverlib::ServerLambda lambda,
|
||||
concretelang::clientlib::PublicArguments &args);
|
||||
|
||||
// Client Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
key_set(concretelang::clientlib::ClientParameters clientParameters,
|
||||
llvm::Optional<concretelang::clientlib::KeySetCache> cache);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args);
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument
|
||||
decrypt_result(concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult);
|
||||
|
||||
// Build lambda from a textual representation of an MLIR module
|
||||
// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if not
|
||||
// null) as a shared library during compilation,
|
||||
// a path to activate the use a cache for encryption keys for test purpose
|
||||
// (unsecure), and a set of flags for parallelization.
|
||||
// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if
|
||||
// not null) as a shared library during compilation, a path to activate the
|
||||
// use a cache for encryption keys for test purpose (unsecure).
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName,
|
||||
const char *runtimeLibPath, const char *keySetCachePath,
|
||||
|
||||
@@ -35,8 +35,7 @@ public:
|
||||
JitLambdaSupport(
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = llvm::None,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline =
|
||||
mlir::makeOptimizingTransformer(3, 0, nullptr))
|
||||
: runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {}
|
||||
mlir::makeOptimizingTransformer(3, 0, nullptr));
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, std::string funcname = "main") override;
|
||||
|
||||
@@ -33,8 +33,7 @@ class LibraryLambdaSupport
|
||||
: public LambdaSupport<serverlib::ServerLambda, LibraryCompilationResult> {
|
||||
|
||||
public:
|
||||
LibraryLambdaSupport(std::string outputPath = "/tmp/toto")
|
||||
: outputPath(outputPath) {}
|
||||
LibraryLambdaSupport(std::string outputPath) : outputPath(outputPath) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, std::string funcname = "main") override {
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/JitLambdaSupport.h"
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
@@ -20,6 +21,7 @@
|
||||
#include <string>
|
||||
|
||||
using mlir::concretelang::JitCompilerEngine;
|
||||
using mlir::concretelang::JitLambdaSupport;
|
||||
using mlir::concretelang::LambdaArgument;
|
||||
|
||||
const char *noEmptyStringPtr(std::string &s) {
|
||||
@@ -54,6 +56,110 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
auto_parallelize, loop_parallelize,
|
||||
df_parallelize);
|
||||
});
|
||||
pybind11::class_<mlir::concretelang::JitCompilationResult>(
|
||||
m, "JitCompilationResult");
|
||||
pybind11::class_<mlir::concretelang::JITLambda>(m, "JITLambda");
|
||||
pybind11::class_<JITLambdaSupport_C>(m, "JITLambdaSupport")
|
||||
.def(pybind11::init([](std::string runtimeLibPath) {
|
||||
return jit_lambda_support(runtimeLibPath.c_str());
|
||||
}))
|
||||
.def("compile",
|
||||
[](JITLambdaSupport_C &support, std::string mlir_program,
|
||||
std::string func_name) {
|
||||
return jit_compile(support, mlir_program.c_str(),
|
||||
func_name.c_str());
|
||||
})
|
||||
.def("load_client_parameters",
|
||||
[](JITLambdaSupport_C &support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
return jit_load_client_parameters(support, result);
|
||||
})
|
||||
.def(
|
||||
"load_server_lambda",
|
||||
[](JITLambdaSupport_C &support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
return jit_load_server_lambda(support, result);
|
||||
},
|
||||
pybind11::return_value_policy::reference)
|
||||
.def("server_call",
|
||||
[](JITLambdaSupport_C &support, concretelang::JITLambda *lambda,
|
||||
clientlib::PublicArguments &publicArguments) {
|
||||
return jit_server_call(support, lambda, publicArguments);
|
||||
});
|
||||
|
||||
pybind11::class_<mlir::concretelang::LibraryCompilationResult>(
|
||||
m, "LibraryCompilationResult")
|
||||
.def(pybind11::init([](std::string libraryPath, std::string funcname) {
|
||||
return mlir::concretelang::LibraryCompilationResult{
|
||||
libraryPath,
|
||||
funcname,
|
||||
};
|
||||
}));
|
||||
pybind11::class_<concretelang::serverlib::ServerLambda>(m, "LibraryLambda");
|
||||
pybind11::class_<LibraryLambdaSupport_C>(m, "LibraryLambdaSupport")
|
||||
.def(pybind11::init([](std::string outputPath) {
|
||||
return library_lambda_support(outputPath.c_str());
|
||||
}))
|
||||
.def("compile",
|
||||
[](LibraryLambdaSupport_C &support, std::string mlir_program,
|
||||
std::string func_name) {
|
||||
return library_compile(support, mlir_program.c_str(),
|
||||
func_name.c_str());
|
||||
})
|
||||
.def("load_client_parameters",
|
||||
[](LibraryLambdaSupport_C &support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
return library_load_client_parameters(support, result);
|
||||
})
|
||||
.def(
|
||||
"load_server_lambda",
|
||||
[](LibraryLambdaSupport_C &support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
return library_load_server_lambda(support, result);
|
||||
},
|
||||
pybind11::return_value_policy::reference)
|
||||
.def("server_call",
|
||||
[](LibraryLambdaSupport_C &support, serverlib::ServerLambda lambda,
|
||||
clientlib::PublicArguments &publicArguments) {
|
||||
return library_server_call(support, lambda, publicArguments);
|
||||
});
|
||||
|
||||
class ClientSupport {};
|
||||
pybind11::class_<ClientSupport>(m, "ClientSupport")
|
||||
.def(pybind11::init())
|
||||
.def_static(
|
||||
"key_set",
|
||||
[](clientlib::ClientParameters clientParameters,
|
||||
clientlib::KeySetCache *cache) {
|
||||
auto optCache =
|
||||
cache == nullptr
|
||||
? llvm::None
|
||||
: llvm::Optional<clientlib::KeySetCache>(*cache);
|
||||
return key_set(clientParameters, optCache);
|
||||
},
|
||||
pybind11::arg().none(false), pybind11::arg().none(true))
|
||||
.def_static("encrypt_arguments",
|
||||
[](clientlib::ClientParameters clientParameters,
|
||||
clientlib::KeySet &keySet,
|
||||
std::vector<lambdaArgument> args) {
|
||||
std::vector<mlir::concretelang::LambdaArgument *> argsRef;
|
||||
for (auto i = 0u; i < args.size(); i++) {
|
||||
argsRef.push_back(args[i].ptr.get());
|
||||
}
|
||||
return encrypt_arguments(clientParameters, keySet, argsRef);
|
||||
})
|
||||
.def_static("decrypt_result", [](clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &publicResult) {
|
||||
return decrypt_result(keySet, publicResult);
|
||||
});
|
||||
pybind11::class_<KeySetCache>(m, "KeySetCache")
|
||||
.def(pybind11::init<std::string &>());
|
||||
|
||||
pybind11::class_<mlir::concretelang::ClientParameters>(m, "ClientParameters");
|
||||
|
||||
pybind11::class_<clientlib::KeySet>(m, "KeySet");
|
||||
pybind11::class_<clientlib::PublicArguments>(m, "PublicArguments");
|
||||
pybind11::class_<clientlib::PublicResult>(m, "PublicResult");
|
||||
|
||||
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
|
||||
.def_static("from_tensor",
|
||||
|
||||
@@ -14,6 +14,23 @@ from mlir._mlir_libs._concretelang._compiler import (
|
||||
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
|
||||
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
|
||||
from mlir._mlir_libs._concretelang._compiler import library as _library
|
||||
from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryLambdaSupport
|
||||
from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import ClientParameters
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import KeySet
|
||||
from mlir._mlir_libs._concretelang._compiler import KeySetCache
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import PublicResult
|
||||
from mlir._mlir_libs._concretelang._compiler import PublicArguments
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import JitCompilationResult
|
||||
from mlir._mlir_libs._concretelang._compiler import JITLambda
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryCompilationResult
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryLambda
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -46,7 +63,8 @@ def _lookup_runtime_lib() -> str:
|
||||
for filename in os.listdir(libs_path)
|
||||
if filename.startswith("libConcretelangRuntime")
|
||||
]
|
||||
assert len(runtime_library_paths) == 1, "should be one and only one runtime library"
|
||||
assert len(
|
||||
runtime_library_paths) == 1, "should be one and only one runtime library"
|
||||
return os.path.join(libs_path, runtime_library_paths[0])
|
||||
|
||||
|
||||
@@ -67,10 +85,10 @@ def round_trip(mlir_str: str) -> str:
|
||||
return _round_trip(mlir_str)
|
||||
|
||||
|
||||
_MLIR_MODULES_TYPE = "mlir_modules must be an `iterable` of `str` or a `str"
|
||||
_MLIR_MODULES_TYPE = 'mlir_modules must be an `iterable` of `str` or a `str'
|
||||
|
||||
|
||||
def library(library_path: str, mlir_modules: Union["Iterable[str]", str]) -> str:
|
||||
def library(library_path: str, mlir_modules: Union['Iterable[str]', str]) -> str:
|
||||
"""Compile the MLIR inputs to a library.
|
||||
|
||||
Args:
|
||||
@@ -101,7 +119,7 @@ def library(library_path: str, mlir_modules: Union["Iterable[str]", str]) -> str
|
||||
return _library(library_path, mlir_modules)
|
||||
|
||||
|
||||
def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument":
|
||||
def create_execution_argument(value: Union[int, np.ndarray]) -> _LambdaArgument:
|
||||
"""Create an execution argument holding either an int or tensor value.
|
||||
|
||||
Args:
|
||||
@@ -115,8 +133,7 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
|
||||
"""
|
||||
if not isinstance(value, ACCEPTED_TYPES):
|
||||
raise TypeError(
|
||||
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
|
||||
)
|
||||
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}")
|
||||
if isinstance(value, ACCEPTED_INTS):
|
||||
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
|
||||
raise TypeError(
|
||||
@@ -134,7 +151,7 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
|
||||
|
||||
class CompilerEngine:
|
||||
def __init__(self, mlir_str: str = None):
|
||||
self._engine = _JitCompilerEngine()
|
||||
self._engine = JITCompilerSupport()
|
||||
self._lambda = None
|
||||
if mlir_str is not None:
|
||||
self.compile_fhe(mlir_str)
|
||||
@@ -182,17 +199,17 @@ class CompilerEngine:
|
||||
)
|
||||
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
|
||||
if not isinstance(unsecure_key_set_cache_path, str):
|
||||
raise TypeError("unsecure_key_set_cache_path must be a str")
|
||||
|
||||
self._lambda = self._engine.build_lambda(
|
||||
mlir_str,
|
||||
func_name,
|
||||
runtime_lib_path,
|
||||
unsecure_key_set_cache_path,
|
||||
auto_parallelize,
|
||||
loop_parallelize,
|
||||
df_parallelize,
|
||||
)
|
||||
raise TypeError(
|
||||
"unsecure_key_set_cache_path must be a str"
|
||||
)
|
||||
self._compilation_result = self._engine.compile(mlir_str)
|
||||
self._client_parameters = self._engine.load_client_parameters(
|
||||
self._compilation_result)
|
||||
keyset_cache = None
|
||||
if not unsecure_key_set_cache_path is None:
|
||||
keyset_cache = KeySetCache(unsecure_key_set_cache_path)
|
||||
self._key_set = ClientSupport.key_set(
|
||||
self._client_parameters, keyset_cache)
|
||||
|
||||
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
|
||||
"""Run the compiled code.
|
||||
@@ -208,10 +225,59 @@ class CompilerEngine:
|
||||
Returns:
|
||||
int or numpy.array: result of execution.
|
||||
"""
|
||||
if self._lambda is None:
|
||||
if self._compilation_result is None:
|
||||
raise RuntimeError("need to compile an MLIR code first")
|
||||
# Client
|
||||
public_arguments = ClientSupport.encrypt_arguments(self._client_parameters,
|
||||
self._key_set, args)
|
||||
# Server
|
||||
server_lambda = self._engine.load_server_lambda(
|
||||
self._compilation_result)
|
||||
public_result = self._engine.server_call(
|
||||
server_lambda, public_arguments)
|
||||
# Client
|
||||
return ClientSupport.decrypt_result(self._key_set, public_result)
|
||||
|
||||
|
||||
class ClientSupport:
|
||||
def key_set(client_parameters: ClientParameters, cache: KeySetCache = None) -> KeySet:
|
||||
"""Generates a key set according to the given client parameters.
|
||||
If the cache is set the key set is loaded from it if exists, else the new generated key set is saved in the cache
|
||||
|
||||
Args:
|
||||
client_parameters: A client parameters specification
|
||||
cache: An optional cache of key set.
|
||||
|
||||
Returns:
|
||||
KeySet: the key set
|
||||
"""
|
||||
return _ClientSupport.key_set(client_parameters, cache)
|
||||
|
||||
def encrypt_arguments(client_parameters: ClientParameters, key_set: KeySet, args: List[Union[int, np.ndarray]]) -> PublicArguments:
|
||||
"""Export clear arguments to public arguments.
|
||||
For each arguments this method encrypts the argument if it's declared as encrypted and pack to the public arguments object.
|
||||
|
||||
Args:
|
||||
client_parameters: A client parameters specification
|
||||
key_set: A key set used to encrypt encrypted arguments
|
||||
|
||||
Returns:
|
||||
PublicArguments: the public arguments
|
||||
"""
|
||||
execution_arguments = [create_execution_argument(arg) for arg in args]
|
||||
lambda_arg = self._lambda.invoke(execution_arguments)
|
||||
return _ClientSupport.encrypt_arguments(client_parameters, key_set, execution_arguments)
|
||||
|
||||
def decrypt_result(key_set: KeySet, public_result: PublicResult) -> Union[int, np.ndarray]:
|
||||
"""Decrypt a public result thanks the given key set.
|
||||
|
||||
Args:
|
||||
key_set: The key set used to decrypt the result.
|
||||
public_result: The public result to descrypt.
|
||||
|
||||
Returns:
|
||||
int or numpy.array: The result of decryption.
|
||||
"""
|
||||
lambda_arg = _ClientSupport.decrypt_result(key_set, public_result)
|
||||
if lambda_arg.is_scalar():
|
||||
return lambda_arg.get_scalar()
|
||||
elif lambda_arg.is_tensor():
|
||||
@@ -220,3 +286,103 @@ class CompilerEngine:
|
||||
return tensor
|
||||
else:
|
||||
raise RuntimeError("unknown return type")
|
||||
|
||||
|
||||
class JITCompilerSupport:
|
||||
def __init__(self, runtime_lib_path=None):
|
||||
if runtime_lib_path is None:
|
||||
runtime_lib_path = _lookup_runtime_lib()
|
||||
self._support = JITLambdaSupport(runtime_lib_path)
|
||||
|
||||
def compile(self, mlir_program: str, func_name: str = "main") -> JitCompilationResult:
|
||||
"""JIT Compile a function define in the mlir_program to its homomorphic equivalent.
|
||||
|
||||
Args:
|
||||
mlir_program: A textual representation of the mlir program to compile.
|
||||
func_name: The name of the function to compile.
|
||||
|
||||
Returns:
|
||||
JITCompilationResult: the result of the JIT compilation.
|
||||
"""
|
||||
if not isinstance(mlir_program, str):
|
||||
raise TypeError("mlir_program must be an `str`")
|
||||
return self._support.compile(mlir_program, func_name)
|
||||
|
||||
def load_client_parameters(self, compilation_result: JitCompilationResult) -> ClientParameters:
|
||||
"""Load the client parameters from the JIT compilation result"""
|
||||
return self._support.load_client_parameters(compilation_result)
|
||||
|
||||
def load_server_lambda(self, compilation_result: JitCompilationResult) -> JITLambda:
|
||||
"""Load the server lambda from the JIT compilation result"""
|
||||
return self._support.load_server_lambda(compilation_result)
|
||||
|
||||
def server_call(self, server_lambda: JITLambda, public_arguments: PublicArguments):
|
||||
"""Call the server lambda with public_arguments
|
||||
|
||||
Args:
|
||||
server_lambda: A server lambda to call
|
||||
public_arguments: The arguments of the call
|
||||
|
||||
Returns:
|
||||
PublicResult: the result of the call of the server lambda
|
||||
"""
|
||||
return self._support.server_call(server_lambda, public_arguments)
|
||||
|
||||
|
||||
class LibraryCompilerSupport:
|
||||
def __init__(self, outputPath="./out"):
|
||||
self._library_path = outputPath
|
||||
self._support = LibraryLambdaSupport(outputPath)
|
||||
|
||||
def compile(self, mlir_program: str, func_name: str = "main") -> LibraryCompilationResult:
|
||||
"""Compile a function define in the mlir_program to its homomorphic equivalent and save as library.
|
||||
|
||||
Args:
|
||||
mlir_program: A textual representation of the mlir program to compile.
|
||||
func_name: The name of the function to compile.
|
||||
|
||||
Returns:
|
||||
LibraryCompilationResult: the result of the compilation.
|
||||
"""
|
||||
if not isinstance(mlir_program, str):
|
||||
raise TypeError("mlir_program must be an `str`")
|
||||
if not isinstance(func_name, str):
|
||||
raise TypeError("mlir_program must be an `str`")
|
||||
return self._support.compile(mlir_program, func_name)
|
||||
|
||||
def reload(self, func_name: str = "main") -> LibraryCompilationResult:
|
||||
"""Reload the library compilation result from the outputPath.
|
||||
Args:
|
||||
library-path: The path of the compiled library.
|
||||
func_name: The name of the compiled function.
|
||||
|
||||
Returns:
|
||||
LibraryCompilationResult: the result of a compilation.
|
||||
"""
|
||||
if not isinstance(func_name, str):
|
||||
raise TypeError("func_name must be an `str`")
|
||||
return LibraryCompilationResult(self._library_path, func_name)
|
||||
|
||||
def load_client_parameters(self, compilation_result: LibraryCompilationResult) -> ClientParameters:
|
||||
"""Load the client parameters from the JIT compilation result"""
|
||||
if not isinstance(compilation_result, LibraryCompilationResult):
|
||||
raise TypeError(
|
||||
"compilation_result must be an `LibraryCompilationResult`")
|
||||
|
||||
return self._support.load_client_parameters(compilation_result)
|
||||
|
||||
def load_server_lambda(self, compilation_result: LibraryCompilationResult) -> LibraryLambda:
|
||||
"""Load the server lambda from the JIT compilation result"""
|
||||
return self._support.load_server_lambda(compilation_result)
|
||||
|
||||
def server_call(self, server_lambda: LibraryLambda, public_arguments: PublicArguments) -> PublicResult:
|
||||
"""Call the server lambda with public_arguments
|
||||
|
||||
Args:
|
||||
server_lambda: A server lambda to call
|
||||
public_arguments: The arguments of the call
|
||||
|
||||
Returns:
|
||||
PublicResult: the result of the call of the server lambda
|
||||
"""
|
||||
return self._support.server_call(server_lambda, public_arguments)
|
||||
|
||||
@@ -11,9 +11,134 @@
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/JitLambdaSupport.h"
|
||||
|
||||
using mlir::concretelang::JitCompilerEngine;
|
||||
|
||||
#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
|
||||
auto VARNAME = EXPECTED; \
|
||||
if (auto err = VARNAME.takeError()) { \
|
||||
throw std::runtime_error(llvm::toString(std::move(err))); \
|
||||
}
|
||||
|
||||
// JIT Support bindings ///////////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED JITLambdaSupport_C
|
||||
jit_lambda_support(const char *runtimeLibPath) {
|
||||
llvm::StringRef str(runtimeLibPath);
|
||||
auto opt = str.empty() ? llvm::None : llvm::Optional<llvm::StringRef>(str);
|
||||
return JITLambdaSupport_C{mlir::concretelang::JitLambdaSupport(opt)};
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile(JITLambdaSupport_C support, const char *module,
|
||||
const char *funcname) {
|
||||
mlir::concretelang::JitLambdaSupport esupport;
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
|
||||
esupport.compile(module, funcname));
|
||||
return std::move(*compilationResult);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
jit_load_client_parameters(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(clientParameters,
|
||||
support.support.loadClientParameters(result));
|
||||
return *clientParameters;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda *
|
||||
jit_load_server_lambda(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
|
||||
support.support.loadServerLambda(result));
|
||||
return *serverLambda;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JITLambda *lambda,
|
||||
concretelang::clientlib::PublicArguments &args) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(publicResult,
|
||||
support.support.serverCall(lambda, args));
|
||||
return std::move(*publicResult);
|
||||
}
|
||||
|
||||
// Library Support bindings ///////////////////////////////////////////////////
|
||||
MLIR_CAPI_EXPORTED LibraryLambdaSupport_C
|
||||
library_lambda_support(const char *outputPath) {
|
||||
return LibraryLambdaSupport_C{
|
||||
mlir::concretelang::LibraryLambdaSupport(outputPath)};
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibraryLambdaSupport_C support, const char *module,
|
||||
const char *funcname) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
|
||||
support.support.compile(module, funcname));
|
||||
return std::move(*compilationResult);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
library_load_client_parameters(
|
||||
LibraryLambdaSupport_C support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(clientParameters,
|
||||
support.support.loadClientParameters(result));
|
||||
return *clientParameters;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
|
||||
library_load_server_lambda(
|
||||
LibraryLambdaSupport_C support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
|
||||
support.support.loadServerLambda(result));
|
||||
return *serverLambda;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
library_server_call(LibraryLambdaSupport_C support,
|
||||
concretelang::serverlib::ServerLambda lambda,
|
||||
concretelang::clientlib::PublicArguments &args) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(publicResult,
|
||||
support.support.serverCall(lambda, args));
|
||||
return std::move(*publicResult);
|
||||
}
|
||||
|
||||
// Client Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
key_set(concretelang::clientlib::ClientParameters clientParameters,
|
||||
llvm::Optional<concretelang::clientlib::KeySetCache> cache) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
ks, (mlir::concretelang::LambdaSupport<int, int>::keySet(clientParameters,
|
||||
cache)));
|
||||
return std::move(*ks);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
publicArguments,
|
||||
(mlir::concretelang::LambdaSupport<int, int>::exportArguments(
|
||||
clientParameters, keySet, args)));
|
||||
return std::move(*publicArguments);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument
|
||||
decrypt_result(concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
result, mlir::concretelang::typedResult<
|
||||
std::unique_ptr<mlir::concretelang::LambdaArgument>>(
|
||||
keySet, publicResult));
|
||||
lambdaArgument result_{std::move(*result)};
|
||||
return std::move(result_);
|
||||
}
|
||||
|
||||
mlir::concretelang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName,
|
||||
const char *runtimeLibPath, const char *keySetCachePath,
|
||||
@@ -216,7 +341,7 @@ std::string library(std::string libraryPath,
|
||||
using namespace mlir::concretelang;
|
||||
|
||||
JitCompilerEngine ce{CompilationContext::createShared()};
|
||||
auto lib = ce.compile<std::string>(mlir_modules, libraryPath);
|
||||
auto lib = ce.compile(mlir_modules, libraryPath);
|
||||
if (!lib) {
|
||||
throw std::runtime_error("Can't link: " + llvm::toString(lib.takeError()));
|
||||
}
|
||||
|
||||
@@ -4,12 +4,20 @@
|
||||
// for license information.
|
||||
|
||||
#include <concretelang/Support/JitLambdaSupport.h>
|
||||
#include <llvm/Support/TargetSelect.h>
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
JitLambdaSupport::JitLambdaSupport(
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline)
|
||||
: runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) {
|
||||
|
||||
// Setup the compiler engine
|
||||
auto context = std::make_shared<CompilationContext>();
|
||||
concretelang::CompilerEngine engine(context);
|
||||
|
||||
@@ -4,80 +4,116 @@ import tempfile
|
||||
import pytest
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine, library
|
||||
from lib.Bindings.Python.concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
|
||||
from lib.Bindings.Python.concrete.compiler import ClientSupport
|
||||
from lib.Bindings.Python.concrete.compiler import KeySetCache
|
||||
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
"""
|
||||
keySetCacheTest = KeySetCache(KEY_SET_CACHE_PATH)
|
||||
|
||||
|
||||
def compile_and_run(engine, mlir_input, args, expected_result):
|
||||
compilation_result = engine.compile(mlir_input)
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
public_result = engine.server_call(server_lambda, public_arguments)
|
||||
# Client
|
||||
result = ClientSupport.decrypt_result(key_set, public_result)
|
||||
# Check result
|
||||
assert type(expected_result) == type(result)
|
||||
if isinstance(expected_result, int):
|
||||
assert result == expected_result
|
||||
else:
|
||||
assert np.all(result == expected_result)
|
||||
|
||||
|
||||
end_to_end_fixture = [
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(5, 7),
|
||||
12,
|
||||
id="add_eint_int",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
(5, 7),
|
||||
12,
|
||||
id="add_eint_int",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(np.array(4, dtype=np.uint8), np.array(5, dtype=np.uint8)),
|
||||
9,
|
||||
id="add_eint_int_with_ndarray_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
(np.array(4, dtype=np.uint8), np.array(5, dtype=np.uint8)),
|
||||
9,
|
||||
id="add_eint_int_with_ndarray_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(np.uint8(3), np.uint8(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint8_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
(np.uint8(3), np.uint8(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint8_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(np.uint16(3), np.uint16(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint16_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
(np.uint16(3), np.uint16(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint16_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(np.uint32(3), np.uint32(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint32_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
(np.uint32(3), np.uint32(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint32_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(np.uint64(3), np.uint64(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint64_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
(np.uint64(3), np.uint64(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint64_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(73,),
|
||||
73,
|
||||
id="apply_lookup_table",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
|
||||
{
|
||||
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
@@ -85,15 +121,15 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
return %ret : !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint8),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint8),
|
||||
),
|
||||
20,
|
||||
id="dot_eint_int_uint8",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint8),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint8),
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
20,
|
||||
id="dot_eint_int_uint8",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
|
||||
{
|
||||
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
@@ -101,15 +137,15 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
return %ret : !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint16),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint16),
|
||||
),
|
||||
20,
|
||||
id="dot_eint_int_uint16",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint16),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint16),
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
20,
|
||||
id="dot_eint_int_uint16",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
|
||||
{
|
||||
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
@@ -117,15 +153,15 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
return %ret : !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint32),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint32),
|
||||
),
|
||||
20,
|
||||
id="dot_eint_int_uint32",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint32),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint32),
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
20,
|
||||
id="dot_eint_int_uint32",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
|
||||
{
|
||||
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
@@ -133,91 +169,128 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
return %ret : !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint64),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint64),
|
||||
),
|
||||
20,
|
||||
id="dot_eint_int_uint64",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint64),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint64),
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
20,
|
||||
id="dot_eint_int_uint64",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([31, 6, 12, 9], dtype=np.uint8),
|
||||
np.array([32, 9, 2, 3], dtype=np.uint8),
|
||||
),
|
||||
np.array([63, 15, 14, 12]),
|
||||
id="add_eint_int_1D",
|
||||
(
|
||||
np.array([31, 6, 12, 9], dtype=np.uint8),
|
||||
np.array([32, 9, 2, 3], dtype=np.uint8),
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
np.array([63, 15, 14, 12]),
|
||||
id="add_eint_int_1D",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%a0: tensor<4x4x!FHE.eint<6>>, %a1: tensor<4x4xi7>) -> tensor<4x4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x4x!FHE.eint<6>>, tensor<4x4xi7>) -> tensor<4x4x!FHE.eint<6>>
|
||||
return %res : tensor<4x4x!FHE.eint<6>>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array(
|
||||
[[31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
np.array(
|
||||
[[32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
),
|
||||
(
|
||||
np.array(
|
||||
[
|
||||
[63, 15, 14, 12],
|
||||
[63, 15, 14, 12],
|
||||
[63, 15, 14, 12],
|
||||
[63, 15, 14, 12],
|
||||
],
|
||||
[[31, 6, 12, 9], [31, 6, 12, 9], [
|
||||
31, 6, 12, 9], [31, 6, 12, 9]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
np.array(
|
||||
[[32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
id="add_eint_int_2D",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
np.array(
|
||||
[
|
||||
[63, 15, 14, 12],
|
||||
[63, 15, 14, 12],
|
||||
[63, 15, 14, 12],
|
||||
[63, 15, 14, 12],
|
||||
],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
id="add_eint_int_2D",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%a0: tensor<2x2x2x!FHE.eint<6>>, %a1: tensor<2x2x2xi7>) -> tensor<2x2x2x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x2x!FHE.eint<6>>, tensor<2x2x2xi7>) -> tensor<2x2x2x!FHE.eint<6>>
|
||||
return %res : tensor<2x2x2x!FHE.eint<6>>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array(
|
||||
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
np.array(
|
||||
[[[9, 10], [11, 12]], [[13, 14], [15, 16]]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
),
|
||||
(
|
||||
np.array(
|
||||
[[[10, 12], [14, 16]], [[18, 20], [22, 24]]],
|
||||
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
np.array(
|
||||
[[[9, 10], [11, 12]], [[13, 14], [15, 16]]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
id="add_eint_int_3D",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
|
||||
if isinstance(expected_result, int):
|
||||
assert engine.run(*args) == expected_result
|
||||
else:
|
||||
# numpy array
|
||||
assert np.all(engine.run(*args) == expected_result)
|
||||
np.array(
|
||||
[[[10, 12], [14, 16]], [[18, 20], [22, 24]]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
id="add_eint_int_3D",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
end_to_end_fixture
|
||||
)
|
||||
def test_jit_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = JITCompilerSupport()
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
end_to_end_fixture
|
||||
)
|
||||
def test_lib_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = LibraryCompilerSupport("py_test_lib_compile_and_run")
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
end_to_end_fixture
|
||||
)
|
||||
def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
engine = LibraryCompilerSupport("test_lib_compile_reload_and_run")
|
||||
# Here don't save compilation result, reload
|
||||
engine.compile(mlir_input)
|
||||
compilation_result = engine.reload()
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
public_result = engine.server_call(server_lambda, public_arguments)
|
||||
# Client
|
||||
result = ClientSupport.decrypt_result(key_set, public_result)
|
||||
# Check result
|
||||
assert type(expected_result) == type(result)
|
||||
if isinstance(expected_result, int):
|
||||
assert result == expected_result
|
||||
else:
|
||||
assert np.all(result == expected_result)
|
||||
|
||||
|
||||
@ pytest.mark.parametrize(
|
||||
"mlir_input, args",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -234,8 +307,9 @@ def test_compile_and_run(mlir_input, args, expected_result):
|
||||
)
|
||||
def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
|
||||
with pytest.raises(ValueError, match=r"wrong number of arguments"):
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
with pytest.raises(RuntimeError, match=r"function has arity 2 but is applied to too many arguments"):
|
||||
engine.run(*args)
|
||||
|
||||
|
||||
@@ -259,7 +333,8 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
)
|
||||
def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
|
||||
|
||||
|
||||
@@ -281,8 +356,9 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
)
|
||||
def test_compile_invalid(mlir_input):
|
||||
engine = CompilerEngine()
|
||||
with pytest.raises(RuntimeError, match=r"Compilation failed:"):
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
|
||||
with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"):
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
|
||||
|
||||
MODULE_1 = """
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
\
|
||||
auto desc = GetParam(); \
|
||||
\
|
||||
LambdaSupport support; \
|
||||
auto support = LambdaSupport; \
|
||||
\
|
||||
/* 1 - Compile the program */ \
|
||||
auto compilationResult = support.compile(desc.program); \
|
||||
@@ -84,8 +84,10 @@
|
||||
|
||||
/// Instantiate the test suite for Jit
|
||||
INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(
|
||||
JitTest, mlir::concretelang::JitLambdaSupport)
|
||||
JitTest, mlir::concretelang::JitLambdaSupport())
|
||||
|
||||
/// Instantiate the test suite for Jit
|
||||
INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(
|
||||
LibraryTest, mlir::concretelang::LibraryLambdaSupport)
|
||||
LibraryTest,
|
||||
mlir::concretelang::LibraryLambdaSupport("/tmp/end_to_end_test_" +
|
||||
desc.description))
|
||||
Reference in New Issue
Block a user