refactor: recreate the Python API using wrappers

All Cpp objects are now wrapped, and calls are being forwarded after
strict type checking for avoiding weird behaviors.
The same Cpp API is now exposed to Python
This commit is contained in:
youben11
2022-03-31 10:51:04 +01:00
committed by Ayoub Benaissa
parent 308504566f
commit 999ab4e5ea
20 changed files with 1304 additions and 387 deletions

View File

@@ -26,7 +26,23 @@ declare_mlir_python_extension(ConcretelangBindingsPythonExtension.Core
declare_mlir_python_sources(ConcretelangBindingsPythonSources
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
SOURCES
concrete/compiler.py
concrete/compiler/__init__.py
concrete/compiler/client_parameters.py
concrete/compiler/client_support.py
concrete/compiler/compilation_options.py
concrete/compiler/jit_compilation_result.py
concrete/compiler/jit_lambda_support.py
concrete/compiler/jit_lambda.py
concrete/compiler/key_set_cache.py
concrete/compiler/key_set.py
concrete/compiler/lambda_argument.py
concrete/compiler/library_compilation_result.py
concrete/compiler/library_lambda_support.py
concrete/compiler/library_lambda.py
concrete/compiler/public_arguments.py
concrete/compiler/public_result.py
concrete/compiler/utils.py
concrete/compiler/wrapper.py
concrete/__init__.py
concrete/lang/__init__.py
concrete/lang/dialects/__init__.py

View File

@@ -44,16 +44,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](CompilationOptions &options, bool b) {
options.verifyDiagnostics = b;
})
.def("auto_parallelize", [](CompilationOptions &options,
bool b) { options.autoParallelize = b; })
.def("loop_parallelize", [](CompilationOptions &options,
bool b) { options.loopParallelize = b; })
.def("dataflow_parallelize", [](CompilationOptions &options, bool b) {
.def("set_auto_parallelize", [](CompilationOptions &options,
bool b) { options.autoParallelize = b; })
.def("set_loop_parallelize", [](CompilationOptions &options,
bool b) { options.loopParallelize = b; })
.def("set_dataflow_parallelize", [](CompilationOptions &options, bool b) {
options.dataflowParallelize = b;
});
pybind11::class_<mlir::concretelang::JitCompilationResult>(
m, "JitCompilationResult");
m, "JITCompilationResult");
pybind11::class_<mlir::concretelang::JITLambda,
std::shared_ptr<mlir::concretelang::JITLambda>>(m,
"JITLambda");

View File

@@ -1,380 +0,0 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""Compiler submodule"""
from collections.abc import Iterable
import os
import atexit
from typing import List, Union
from mlir._mlir_libs._concretelang._compiler import (
terminate_parallelization as _terminate_parallelization,
)
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport
from mlir._mlir_libs._concretelang._compiler import ClientParameters
from mlir._mlir_libs._concretelang._compiler import KeySet
from mlir._mlir_libs._concretelang._compiler import KeySetCache
from mlir._mlir_libs._concretelang._compiler import PublicResult
from mlir._mlir_libs._concretelang._compiler import PublicArguments
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
from mlir._mlir_libs._concretelang._compiler import CompilationOptions
from mlir._mlir_libs._concretelang._compiler import (
JITLambdaSupport as _JITLambdaSupport,
)
from mlir._mlir_libs._concretelang._compiler import JitCompilationResult
from mlir._mlir_libs._concretelang._compiler import JITLambda
from mlir._mlir_libs._concretelang._compiler import (
LibraryLambdaSupport as _LibraryLambdaSupport,
)
from mlir._mlir_libs._concretelang._compiler import LibraryCompilationResult
from mlir._mlir_libs._concretelang._compiler import LibraryLambda
import numpy as np
ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64)
ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS
ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS
# Terminate parallelization in the compiler (if init) during cleanup
atexit.register(_terminate_parallelization)
def _lookup_runtime_lib() -> str:
"""Try to find the absolute path to the runtime library.
Returns:
str: absolute path to the runtime library, or empty str if unsuccessful.
"""
# Go up to site-packages level
cwd = os.path.abspath(__file__)
cwd = os.path.abspath(os.path.join(cwd, os.pardir))
cwd = os.path.abspath(os.path.join(cwd, os.pardir))
package_name = "concrete_compiler"
libs_path = os.path.join(cwd, f"{package_name}.libs")
# Can be because it's not a properly installed package
if not os.path.exists(libs_path):
return ""
runtime_library_paths = [
filename
for filename in os.listdir(libs_path)
if filename.startswith("libConcretelangRuntime")
]
assert len(runtime_library_paths) == 1, "should be one and only one runtime library"
return os.path.join(libs_path, runtime_library_paths[0])
def round_trip(mlir_str: str) -> str:
"""Parse the MLIR input, then return it back.
Args:
mlir_str (str): MLIR code to parse.
Raises:
TypeError: if the argument is not an str.
Returns:
str: parsed MLIR input.
"""
if not isinstance(mlir_str, str):
raise TypeError("input must be an `str`")
return _round_trip(mlir_str)
class CompilerEngine:
def __init__(self, mlir_str: str = None):
self._engine = JITCompilerSupport()
self._lambda = None
if mlir_str is not None:
self.compile_fhe(mlir_str)
def compile_fhe(
self,
mlir_str: str,
func_name: str = "main",
unsecure_key_set_cache_path: str = None,
auto_parallelize: bool = False,
loop_parallelize: bool = False,
df_parallelize: bool = False,
):
"""Compile the MLIR input.
Args:
mlir_str (str): MLIR to compile.
func_name (str): name of the function to set as entrypoint (default: main).
unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None).
auto_parallelize (bool): whether to activate auto-parallelization or not (default: False),
loop_parallelize (bool): whether to activate loop-parallelization or not (default: False),
df_parallelize (bool): whether to activate dataflow-parallelization or not (default: False),
Raises:
TypeError: if the argument is not an str.
"""
if not all(
isinstance(flag, bool)
for flag in [auto_parallelize, loop_parallelize, df_parallelize]
):
raise TypeError(
"parallelization flags (auto_parallelize, loop_parallelize, df_parallelize), should be booleans"
)
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
if not isinstance(unsecure_key_set_cache_path, str):
raise TypeError("unsecure_key_set_cache_path must be a str")
options = CompilationOptions(func_name)
options.auto_parallelize(auto_parallelize)
options.loop_parallelize(loop_parallelize)
options.dataflow_parallelize(df_parallelize)
self._compilation_result = self._engine.compile(mlir_str, options)
self._client_parameters = self._engine.load_client_parameters(
self._compilation_result
)
keyset_cache = None
if not unsecure_key_set_cache_path is None:
keyset_cache = KeySetCache(unsecure_key_set_cache_path)
self._key_set = ClientSupport.key_set(self._client_parameters, keyset_cache)
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
"""Run the compiled code.
Args:
*args: list of arguments for execution. Each argument can be an int, or a numpy.array
Raises:
TypeError: if execution arguments can't be constructed
RuntimeError: if the engine has not compiled any code yet
RuntimeError: if the return type is unknown
Returns:
int or numpy.array: result of execution.
"""
if self._compilation_result is None:
raise RuntimeError("need to compile an MLIR code first")
# Client
public_arguments = ClientSupport.encrypt_arguments(
self._client_parameters, self._key_set, args
)
# Server
server_lambda = self._engine.load_server_lambda(self._compilation_result)
public_result = self._engine.server_call(server_lambda, public_arguments)
# Client
return ClientSupport.decrypt_result(self._key_set, public_result)
class ClientSupport:
def key_set(
client_parameters: ClientParameters, cache: KeySetCache = None
) -> KeySet:
"""Generates a key set according to the given client parameters.
If the cache is set the key set is loaded from it if exists, else the new generated key set is saved in the cache
Args:
client_parameters: A client parameters specification
cache: An optional cache of key set.
Returns:
KeySet: the key set
"""
return _ClientSupport.key_set(client_parameters, cache)
def encrypt_arguments(
client_parameters: ClientParameters,
key_set: KeySet,
args: List[Union[int, np.ndarray]],
) -> PublicArguments:
"""Export clear arguments to public arguments.
For each arguments this method encrypts the argument if it's declared as encrypted and pack to the public arguments object.
Args:
client_parameters: A client parameters specification
key_set: A key set used to encrypt encrypted arguments
Returns:
PublicArguments: the public arguments
"""
execution_arguments = [
ClientSupport._create_execution_argument(arg) for arg in args
]
return _ClientSupport.encrypt_arguments(
client_parameters, key_set, execution_arguments
)
def decrypt_result(
key_set: KeySet, public_result: PublicResult
) -> Union[int, np.ndarray]:
"""Decrypt a public result thanks the given key set.
Args:
key_set: The key set used to decrypt the result.
public_result: The public result to descrypt.
Returns:
int or numpy.array: The result of decryption.
"""
lambda_arg = _ClientSupport.decrypt_result(key_set, public_result)
if lambda_arg.is_scalar():
return lambda_arg.get_scalar()
elif lambda_arg.is_tensor():
shape = lambda_arg.get_tensor_shape()
tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape)
return tensor
else:
raise RuntimeError("unknown return type")
def _create_execution_argument(value: Union[int, np.ndarray]) -> _LambdaArgument:
"""Create an execution argument holding either an int or tensor value.
Args:
value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array
Raises:
TypeError: if the values aren't in the expected range, or using a wrong type
Returns:
_LambdaArgument: lambda argument holding the appropriate value
"""
if not isinstance(value, ACCEPTED_TYPES):
raise TypeError(
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
)
if isinstance(value, ACCEPTED_INTS):
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
raise TypeError(
"single integer must be in the range [0, 2**64 - 1] (uint64)"
)
return _LambdaArgument.from_scalar(value)
else:
assert isinstance(value, np.ndarray)
if value.shape == ():
return _LambdaArgument.from_scalar(value)
if value.dtype not in ACCEPTED_NUMPY_UINTS:
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
class JITCompilerSupport:
def __init__(self, runtime_lib_path=None):
if runtime_lib_path is None:
runtime_lib_path = _lookup_runtime_lib()
else:
if not isinstance(runtime_lib_path, str):
raise TypeError(
"runtime_lib_path must be an str representing the path to the runtime lib"
)
self._support = _JITLambdaSupport(runtime_lib_path)
def compile(
self,
mlir_program: str,
options: CompilationOptions = CompilationOptions("main"),
) -> JitCompilationResult:
"""JIT Compile a function define in the mlir_program to its homomorphic equivalent.
Args:
mlir_program: A textual representation of the mlir program to compile.
func_name: The name of the function to compile.
Returns:
JITCompilationResult: the result of the JIT compilation.
"""
if not isinstance(mlir_program, str):
raise TypeError("mlir_program must be an `str`")
return self._support.compile(mlir_program, options)
def load_client_parameters(
self, compilation_result: JitCompilationResult
) -> ClientParameters:
"""Load the client parameters from the JIT compilation result"""
return self._support.load_client_parameters(compilation_result)
def load_server_lambda(self, compilation_result: JitCompilationResult) -> JITLambda:
"""Load the server lambda from the JIT compilation result"""
return self._support.load_server_lambda(compilation_result)
def server_call(self, server_lambda: JITLambda, public_arguments: PublicArguments):
"""Call the server lambda with public_arguments
Args:
server_lambda: A server lambda to call
public_arguments: The arguments of the call
Returns:
PublicResult: the result of the call of the server lambda
"""
return self._support.server_call(server_lambda, public_arguments)
class LibraryCompilerSupport:
def __init__(self, outputPath="./out"):
self._library_path = outputPath
self._support = _LibraryLambdaSupport(outputPath)
def compile(
self,
mlir_program: str,
options: CompilationOptions = CompilationOptions("main"),
) -> LibraryCompilationResult:
"""Compile a function define in the mlir_program to its homomorphic equivalent and save as library.
Args:
mlir_program: A textual representation of the mlir program to compile.
func_name: The name of the function to compile.
Returns:
LibraryCompilationResult: the result of the compilation.
"""
if not isinstance(mlir_program, str):
raise TypeError("mlir_program must be an `str`")
if not isinstance(options, CompilationOptions):
raise TypeError("mlir_program must be an `str`")
return self._support.compile(mlir_program, options)
def reload(self, func_name: str = "main") -> LibraryCompilationResult:
"""Reload the library compilation result from the outputPath.
Args:
library-path: The path of the compiled library.
func_name: The name of the compiled function.
Returns:
LibraryCompilationResult: the result of a compilation.
"""
if not isinstance(func_name, str):
raise TypeError("func_name must be an `str`")
return LibraryCompilationResult(self._library_path, func_name)
def load_client_parameters(
self, compilation_result: LibraryCompilationResult
) -> ClientParameters:
"""Load the client parameters from the JIT compilation result"""
if not isinstance(compilation_result, LibraryCompilationResult):
raise TypeError("compilation_result must be an `LibraryCompilationResult`")
return self._support.load_client_parameters(compilation_result)
def load_server_lambda(
self, compilation_result: LibraryCompilationResult
) -> LibraryLambda:
"""Load the server lambda from the JIT compilation result"""
return self._support.load_server_lambda(compilation_result)
def server_call(
self, server_lambda: LibraryLambda, public_arguments: PublicArguments
) -> PublicResult:
"""Call the server lambda with public_arguments
Args:
server_lambda: A server lambda to call
public_arguments: The arguments of the call
Returns:
PublicResult: the result of the call of the server lambda
"""
return self._support.server_call(server_lambda, public_arguments)

View File

@@ -0,0 +1,48 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""Compiler submodule."""
import atexit
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
terminate_parallelization as _terminate_parallelization,
)
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
# pylint: enable=no-name-in-module,import-error
from .compilation_options import CompilationOptions
from .key_set_cache import KeySetCache
from .client_parameters import ClientParameters
from .key_set import KeySet
from .public_result import PublicResult
from .public_arguments import PublicArguments
from .jit_compilation_result import JITCompilationResult
from .jit_lambda import JITLambda
from .client_support import ClientSupport
from .jit_lambda_support import JITLambdaSupport
from .library_lambda_support import LibraryLambdaSupport
# Terminate parallelization in the compiler (if init) during cleanup
atexit.register(_terminate_parallelization)
def round_trip(mlir_str: str) -> str:
"""Parse the MLIR input, then return it back.
Useful to check the validity of an MLIR representation
Args:
mlir_str (str): textual representation of an MLIR code
Raises:
TypeError: if mlir_str is not of type str
Returns:
str: textual representation of the MLIR code after parsing
"""
if not isinstance(mlir_str, str):
raise TypeError(f"mlir_str must be of type str, not {type(mlir_str)}")
return _round_trip(mlir_str)

View File

@@ -0,0 +1,36 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""Client parameters."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
ClientParameters as _ClientParameters,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class ClientParameters(WrapperCpp):
"""ClientParameters are public parameters used for key generation.
It's a compilation artifact that describes which and how public and private keys should be generated,
and used to encrypt arguments of the compiled function.
"""
def __init__(self, client_parameters: _ClientParameters):
"""Wrap the native Cpp object.
Args:
client_parameters (_ClientParameters): object to wrap
Raises:
TypeError: if client_parameters is not of type _ClientParameters
"""
if not isinstance(client_parameters, _ClientParameters):
raise TypeError(
f"client_parameters must be of type _ClientParameters, not {type(client_parameters)}"
)
super().__init__(client_parameters)

View File

@@ -0,0 +1,191 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""Client support."""
from typing import List, Optional, Union
import numpy as np
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport
# pylint: enable=no-name-in-module,import-error
from .public_result import PublicResult
from .key_set import KeySet
from .key_set_cache import KeySetCache
from .client_parameters import ClientParameters
from .public_arguments import PublicArguments
from .lambda_argument import LambdaArgument
from .wrapper import WrapperCpp
from .utils import ACCEPTED_INTS, ACCEPTED_NUMPY_UINTS, ACCEPTED_TYPES
class ClientSupport(WrapperCpp):
"""Client interface for doing key generation and encryption.
It provides features that are needed on the client side:
- Generation of public and private keys required for the encrypted computation
- Encryption and preparation of public arguments, used later as input to the computation
- Decryption of public result returned after execution
"""
def __init__(self, client_support: _ClientSupport):
"""Wrap the native Cpp object.
Args:
client_support (_ClientSupport): object to wrap
Raises:
TypeError: if client_support is not of type _ClientSupport
"""
if not isinstance(client_support, _ClientSupport):
raise TypeError(
f"client_support must be of type _ClientSupport not {type(client_support)}"
)
super().__init__(client_support)
@staticmethod
# pylint: disable=arguments-differ
def new() -> "ClientSupport":
"""Build a ClientSupport.
Returns:
ClientSupport
"""
return ClientSupport.wrap(_ClientSupport())
# pylint: enable=arguments-differ
@staticmethod
def key_set(
client_parameters: ClientParameters, keyset_cache: Optional[KeySetCache] = None
) -> KeySet:
"""Generate a key set according to the client parameters.
If the cache is set, and include equivalent keys as specified by the client parameters, the keyset
is loaded, otherwise, a new keyset is generated and saved in the cache.
Args:
client_parameters (ClientParameters): client parameters specification
keyset_cache (Optional[KeySetCache], optional): keyset cache. Defaults to None.
Raises:
TypeError: if client_parameters is not of type ClientParameters
TypeError: if keyset_cache is not of type KeySetCache
Returns:
KeySet: generated or loaded keyset
"""
if keyset_cache is not None and not isinstance(keyset_cache, KeySetCache):
raise TypeError(
f"keyset_cache must be None or of type KeySetCache, not {type(keyset_cache)}"
)
cpp_cache = None if keyset_cache is None else keyset_cache.cpp()
return KeySet.wrap(_ClientSupport.key_set(client_parameters.cpp(), cpp_cache))
@staticmethod
def encrypt_arguments(
client_parameters: ClientParameters,
keyset: KeySet,
args: List[Union[int, np.ndarray]],
) -> PublicArguments:
"""Prepare arguments for encrypted computation.
Pack public arguments by encrypting the ones that requires encryption, and leaving the rest as plain.
It also pack public materials (public keys) that are required during the computation.
Args:
client_parameters (ClientParameters): client parameters specification
keyset (KeySet): keyset used to encrypt arguments that require encryption
args (List[Union[int, np.ndarray]]): list of scalar or tensor arguments
Raises:
TypeError: if client_parameters is not of type ClientParameters
TypeError: if keyset is not of type KeySet
Returns:
PublicArguments: public arguments for execution
"""
if not isinstance(client_parameters, ClientParameters):
raise TypeError(
f"client_parameters must be of type ClientParameters, not {type(client_parameters)}"
)
if not isinstance(keyset, KeySet):
raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}")
lambda_arguments = [ClientSupport._create_lambda_argument(arg) for arg in args]
return PublicArguments.wrap(
_ClientSupport.encrypt_arguments(
client_parameters.cpp(),
keyset.cpp(),
[arg.cpp() for arg in lambda_arguments],
)
)
@staticmethod
def decrypt_result(
keyset: KeySet, public_result: PublicResult
) -> Union[int, np.ndarray]:
"""Decrypt a public result using the keyset.
Args:
keyset (KeySet): keyset used for decryption
public_result: public result to decrypt
Raises:
TypeError: if keyset is not of type KeySet
TypeError: if public_result is not of type PublicResult
RuntimeError: if the result is of an unknown type
Returns:
Union[int, np.ndarray]: plain result
"""
if not isinstance(keyset, KeySet):
raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}")
if not isinstance(public_result, PublicResult):
raise TypeError(
f"public_result must be of type PublicResult, not {type(public_result)}"
)
lambda_arg = LambdaArgument.wrap(
_ClientSupport.decrypt_result(keyset.cpp(), public_result.cpp())
)
if lambda_arg.is_scalar():
return lambda_arg.get_scalar()
if lambda_arg.is_tensor():
shape = lambda_arg.get_tensor_shape()
tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape)
return tensor
raise RuntimeError("unknown return type")
@staticmethod
def _create_lambda_argument(value: Union[int, np.ndarray]) -> LambdaArgument:
"""Create a lambda argument holding either an int or tensor value.
Args:
value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array
Raises:
TypeError: if the values aren't in the expected range, or using a wrong type
Returns:
LambdaArgument: lambda argument holding the appropriate value
"""
if not isinstance(value, ACCEPTED_TYPES):
raise TypeError(
"value of lambda argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
)
if isinstance(value, ACCEPTED_INTS):
if isinstance(value, int) and not 0 <= value < np.iinfo(np.uint64).max:
raise TypeError(
"single integer must be in the range [0, 2**64 - 1] (uint64)"
)
return LambdaArgument.from_scalar(value)
assert isinstance(value, np.ndarray)
if value.dtype not in ACCEPTED_NUMPY_UINTS:
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
if value.shape == ():
if isinstance(value, np.ndarray):
# extract the single element
value = value.max()
# should be a single uint here
return LambdaArgument.from_scalar(value)
return LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)

View File

@@ -0,0 +1,122 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""CompilationOptions."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
CompilationOptions as _CompilationOptions,
)
from .wrapper import WrapperCpp
# pylint: enable=no-name-in-module,import-error
class CompilationOptions(WrapperCpp):
"""CompilationOptions holds different flags and options of the compilation process.
It controls different parallelization flags, diagnostic verification, and also the name of entrypoint
function.
"""
def __init__(self, compilation_options: _CompilationOptions):
"""Wrap the native Cpp object.
Args:
compilation_options (_CompilationOptions): object to wrap
Raises:
TypeError: if compilation_options is not of type _CompilationOptions
"""
if not isinstance(compilation_options, _CompilationOptions):
raise TypeError(
f"_compilation_options must be of type _CompilationOptions, not {type(compilation_options)}"
)
super().__init__(compilation_options)
@staticmethod
# pylint: disable=arguments-differ
def new(function_name="main") -> "CompilationOptions":
"""Build a CompilationOptions.
Args:
function_name (str, optional): name of the entrypoint function. Defaults to "main".
Raises:
TypeError: if function_name is not an str
Returns:
CompilationOptions
"""
if not isinstance(function_name, str):
raise TypeError(
f"function_name must be of type str not {type(function_name)}"
)
return CompilationOptions.wrap(_CompilationOptions(function_name))
# pylint: enable=arguments-differ
def set_auto_parallelize(self, auto_parallelize: bool):
"""Set option for auto parallelization.
Args:
auto_parallelize (bool): whether to turn it on or off
Raises:
TypeError: if the value to set is not boolean
"""
if not isinstance(auto_parallelize, bool):
raise TypeError("can't set the option to a non-boolean value")
self.cpp().set_auto_parallelize(auto_parallelize)
def set_loop_parallelize(self, loop_parallelize: bool):
"""Set option for loop parallelization.
Args:
loop_parallelize (bool): whether to turn it on or off
Raises:
TypeError: if the value to set is not boolean
"""
if not isinstance(loop_parallelize, bool):
raise TypeError("can't set the option to a non-boolean value")
self.cpp().set_loop_parallelize(loop_parallelize)
def set_verify_diagnostics(self, verify_diagnostics: bool):
"""Set option for diagnostics verification.
Args:
verify_diagnostics (bool): whether to turn it on or off
Raises:
TypeError: if the value to set is not boolean
"""
if not isinstance(verify_diagnostics, bool):
raise TypeError("can't set the option to a non-boolean value")
self.cpp().set_verify_diagnostics(verify_diagnostics)
def set_dataflow_parallelize(self, dataflow_parallelize: bool):
"""Set option for dataflow parallelization.
Args:
dataflow_parallelize (bool): whether to turn it on or off
Raises:
TypeError: if the value to set is not boolean
"""
if not isinstance(dataflow_parallelize, bool):
raise TypeError("can't set the option to a non-boolean value")
self.cpp().set_dataflow_parallelize(dataflow_parallelize)
def set_funcname(self, funcname: str):
"""Set entrypoint function name.
Args:
funcname (str): name of the entrypoint function
Raises:
TypeError: if the value to set is not str
"""
if not isinstance(funcname, str):
raise TypeError("can't set the option to a non-str value")
self.cpp().set_funcname(funcname)

View File

@@ -0,0 +1,38 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""JITCompilationResult."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
JITCompilationResult as _JITCompilationResult,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class JITCompilationResult(WrapperCpp):
"""JITCompilationResult holds the result of a JIT compilation.
It can be instrumented using the JITLambdaSupport to load client parameters and execute the compiled
code.
"""
def __init__(self, jit_compilation_result: _JITCompilationResult):
"""Wrap the native Cpp object.
Args:
jit_compilation_result (_JITCompilationResult): object to wrap
Raises:
TypeError: if jit_compilation_result is not of type _JITCompilationResult
"""
if not isinstance(jit_compilation_result, _JITCompilationResult):
raise TypeError(
f"jit_compilation_result must be of type _JITCompilationResult, not "
f"{type(jit_compilation_result)}"
)
super().__init__(jit_compilation_result)

View File

@@ -0,0 +1,35 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""JITLambda."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
JITLambda as _JITLambda,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class JITLambda(WrapperCpp):
"""JITLambda contains an in-memory executable code and can be ran using JITLambdaSupport.
It's an artifact of JIT compilation, which stays in memory and can be executed with the help of
JITLambdaSupport.
"""
def __init__(self, jit_lambda: _JITLambda):
"""Wrap the native Cpp object.
Args:
jit_lambda (_JITLambda): object to wrap
Raises:
TypeError: if jit_lambda is not of type JITLambda
"""
if not isinstance(jit_lambda, _JITLambda):
raise TypeError(
f"jit_lambda must be of type _JITLambda, not {type(jit_lambda)}"
)
super().__init__(jit_lambda)

View File

@@ -0,0 +1,167 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""JITLambdaSupport.
Just-in-time compilation provide a way to compile and execute an MLIR program while keeping the executable
code in memory.
"""
from typing import Optional
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
JITLambdaSupport as _JITLambdaSupport,
)
# pylint: enable=no-name-in-module,import-error
from .utils import lookup_runtime_lib
from .compilation_options import CompilationOptions
from .jit_compilation_result import JITCompilationResult
from .client_parameters import ClientParameters
from .jit_lambda import JITLambda
from .public_arguments import PublicArguments
from .public_result import PublicResult
from .wrapper import WrapperCpp
class JITLambdaSupport(WrapperCpp):
"""Support class for JIT compilation and execution."""
def __init__(self, jit_lambda_support: _JITLambdaSupport):
"""Wrap the native Cpp object.
Args:
jit_lambda_support (_JITLambdaSupport): object to wrap
Raises:
TypeError: if jit_lambda_support is not of type _JITLambdaSupport
"""
if not isinstance(jit_lambda_support, _JITLambdaSupport):
raise TypeError(
f"jit_lambda_support must be of type _JITLambdaSupport not{type(jit_lambda_support)}"
)
super().__init__(jit_lambda_support)
@staticmethod
# pylint: disable=arguments-differ
def new(runtime_lib_path: Optional[str] = None) -> "JITLambdaSupport":
"""Build a JITLambdaSupport.
Args:
runtime_lib_path (Optional[str]): path to the runtime library. Defaults to None.
Raises:
TypeError: if runtime_lib_path is not of type str or None
Returns:
JITLambdaSupport
"""
if runtime_lib_path is None:
runtime_lib_path = lookup_runtime_lib()
else:
if not isinstance(runtime_lib_path, str):
raise TypeError(
f"runtime_lib_path must be of type str, not {type(runtime_lib_path)}"
)
return JITLambdaSupport.wrap(_JITLambdaSupport(runtime_lib_path))
# pylint: enable=arguments-differ
def compile(
self,
mlir_program: str,
options: CompilationOptions = CompilationOptions.new("main"),
) -> JITCompilationResult:
"""JIT compile an MLIR program using Concrete dialects.
Args:
mlir_program (str): textual representation of the mlir program to compile
options (CompilationOptions): compilation options
Raises:
TypeError: if mlir_program is not of type str
TypeError: if options is not of type CompilationOptions
Returns:
JITCompilationResult: the result of the JIT compilation
"""
if not isinstance(mlir_program, str):
raise TypeError(
f"mlir_program must be of type str, not {type(mlir_program)}"
)
if not isinstance(options, CompilationOptions):
raise TypeError(
f"options must be of type CompilationOptions, not {type(options)}"
)
return JITCompilationResult.wrap(
self.cpp().compile(mlir_program, options.cpp())
)
def load_client_parameters(
self, compilation_result: JITCompilationResult
) -> ClientParameters:
"""Load the client parameters from the JIT compilation result.
Args:
compilation_result (JITCompilationResult): result of the JIT compilation
Raises:
TypeError: if compilation_result is not of type JITCompilationResult
Returns:
ClientParameters: appropriate client parameters for the compiled program
"""
if not isinstance(compilation_result, JITCompilationResult):
raise TypeError(
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
)
return ClientParameters.wrap(
self.cpp().load_client_parameters(compilation_result.cpp())
)
def load_server_lambda(self, compilation_result: JITCompilationResult) -> JITLambda:
"""Load the JITLambda from the JIT compilation result.
Args:
compilation_result (JITCompilationResult): result of the JIT compilation.
Raises:
TypeError: if compilation_result is not of type JITCompilationResult
Returns:
JITLambda: loaded JITLambda to be executed
"""
if not isinstance(compilation_result, JITCompilationResult):
raise TypeError(
f"compilation_result must be a JITCompilationResult not {type(compilation_result)}"
)
return JITLambda.wrap(self.cpp().load_server_lambda(compilation_result.cpp()))
def server_call(
self, jit_lambda: JITLambda, public_arguments: PublicArguments
) -> PublicResult:
"""Call the JITLambda with public_arguments.
Args:
jit_lambda (JITLambda): A server lambda to call.
public_arguments (PublicArguments): The arguments of the call.
Raises:
TypeError: if jit_lambda is not of type JITLambda
TypeError: if public_arguments is not of type PublicArguments
Returns:
PublicResult: the result of the call of the server lambda.
"""
if not isinstance(jit_lambda, JITLambda):
raise TypeError(
f"jit_lambda must be of type JITLambda, not {type(jit_lambda)}"
)
if not isinstance(public_arguments, PublicArguments):
raise TypeError(
f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
)
return PublicResult.wrap(
self.cpp().server_call(jit_lambda.cpp(), public_arguments.cpp())
)

View File

@@ -0,0 +1,36 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""KeySet.
Store for the different keys required for an encrypted computation.
"""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
KeySet as _KeySet,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class KeySet(WrapperCpp):
"""KeySet stores the different keys required for an encrypted computation.
Holds private keys (secret key) used for encryption/decryption, and public keys used for computation.
"""
def __init__(self, keyset: _KeySet):
"""Wrap the native Cpp object.
Args:
keyset (_KeySet): object to wrap
Raises:
TypeError: if keyset is not of type _KeySet
"""
if not isinstance(keyset, _KeySet):
raise TypeError(f"keyset must be of type _KeySet, not {type(keyset)}")
super().__init__(keyset)

View File

@@ -0,0 +1,59 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""KeySetCache.
Cache for keys to avoid generating similar keys multiple times.
"""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
KeySetCache as _KeySetCache,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class KeySetCache(WrapperCpp):
"""KeySetCache is a cache for KeySet to avoid generating similar keys multiple times.
Keys get cached and can be later used instead of generating a new keyset which can take a lot of time.
"""
def __init__(self, keyset_cache: _KeySetCache):
"""Wrap the native Cpp object.
Args:
keyset_cache (_KeySetCache): object to wrap
Raises:
TypeError: if keyset_cache is not of type _KeySetCache
"""
if not isinstance(keyset_cache, _KeySetCache):
raise TypeError(
f"key_set_cache must be of type _KeySetCache, not {type(keyset_cache)}"
)
super().__init__(keyset_cache)
@staticmethod
# pylint: disable=arguments-differ
def new(cache_path: str) -> "KeySetCache":
"""Build a KeySetCache located at cache_path.
Args:
cache_path (str): path to the cache
Raises:
TypeError: if the path is not of type str.
Returns:
KeySetCache
"""
if not isinstance(cache_path, str):
raise TypeError(
f"cache_path must to be of type str, not {type(cache_path)}"
)
return KeySetCache.wrap(_KeySetCache(cache_path))
# pylint: enable=arguments-differ

View File

@@ -0,0 +1,116 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""LambdaArgument."""
from typing import List
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
LambdaArgument as _LambdaArgument,
)
# pylint: enable=no-name-in-module,import-error
from .utils import ACCEPTED_INTS
from .wrapper import WrapperCpp
class LambdaArgument(WrapperCpp):
"""LambdaArgument holds scalar or tensor values."""
def __init__(self, lambda_argument: _LambdaArgument):
"""Wrap the native Cpp object.
Args:
lambda_argument (_LambdaArgument): object to wrap
Raises:
TypeError: if lambda_argument is not of type _LambdaArgument
"""
if not isinstance(lambda_argument, _LambdaArgument):
raise TypeError(
f"lambda_argument must be of type _LambdaArgument, not {type(lambda_argument)}"
)
super().__init__(lambda_argument)
@staticmethod
def new(*args, **kwargs):
"""Use from_scalar or from_tensor instead.
Raises:
RuntimeError
"""
raise RuntimeError(
"you should call from_scalar or from_tensor according to the argument type"
)
@staticmethod
def from_scalar(scalar: int) -> "LambdaArgument":
"""Build a LambdaArgument containing the given scalar value.
Args:
scalar (int or numpy.uint): scalar value to embed in LambdaArgument
Raises:
TypeError: if scalar is not of type int or numpy.uint
Returns:
LambdaArgument
"""
if not isinstance(scalar, ACCEPTED_INTS):
raise TypeError(
f"scalar must be of type int or numpy.uint, not {type(scalar)}"
)
return LambdaArgument.wrap(_LambdaArgument.from_scalar(scalar))
@staticmethod
def from_tensor(data: List[int], shape: List[int]) -> "LambdaArgument":
"""Build a LambdaArgument containing the given tensor.
Args:
data (List[int]): flattened tensor data
shape (List[int]): shape of original tensor before flattening
Returns:
LambdaArgument
"""
return LambdaArgument.wrap(_LambdaArgument.from_tensor(data, shape))
def is_scalar(self) -> bool:
"""Check if the contained argument is a scalar.
Returns:
bool
"""
return self.cpp().is_scalar()
def get_scalar(self) -> int:
"""Return the contained scalar value.
Returns:
int
"""
return self.cpp().get_scalar()
def is_tensor(self) -> bool:
"""Check if the contained argument is a tensor.
Returns:
bool
"""
return self.cpp().is_tensor()
def get_tensor_shape(self) -> List[int]:
"""Return the shape of the contained tensor.
Returns:
List[int]: tensor shape
"""
return self.cpp().get_tensor_shape()
def get_tensor_data(self) -> List[int]:
"""Return the contained flattened tensor data.
Returns:
List[int]
"""
return self.cpp().get_tensor_data()

View File

@@ -0,0 +1,60 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""LibraryCompilationResult."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
LibraryCompilationResult as _LibraryCompilationResult,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class LibraryCompilationResult(WrapperCpp):
"""LibraryCompilationResult holds the result of the library compilation."""
def __init__(self, library_compilation_result: _LibraryCompilationResult):
"""Wrap the native Cpp object.
Args:
library_compilation_result (_LibraryCompilationResult): object to wrap
Raises:
TypeError: if library_compilation_result is not of type _LibraryCompilationResult
"""
if not isinstance(library_compilation_result, _LibraryCompilationResult):
raise TypeError(
f"library_compilation_result must be of type _LibraryCompilationResult, not "
f"{type(library_compilation_result)}"
)
super().__init__(library_compilation_result)
@staticmethod
# pylint: disable=arguments-differ
def new(library_path: str, func_name: str) -> "LibraryCompilationResult":
"""Build a LibraryCompilationResult at library_path, with func_name as entrypoint.
Args:
library_path (str): path to the library
func_name (str): entrypoint function name
Raises:
TypeError: if library_path is not of type str
TypeError: if func_name is not of type str
Returns:
LibraryCompilationResult
"""
if not isinstance(library_path, str):
raise TypeError(
f"library_path must be of type str, not {type(library_path)}"
)
if not isinstance(func_name, str):
raise TypeError(f"func_name must be of type str, not {type(func_name)}")
return LibraryCompilationResult.wrap(
_LibraryCompilationResult(library_path, func_name)
)
# pylint: enable=arguments-differ

View File

@@ -0,0 +1,31 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""LibraryLambda."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
LibraryLambda as _LibraryLambda,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class LibraryLambda(WrapperCpp):
"""LibraryLambda reference a compiled library and can be ran using LibraryLambdaSupport."""
def __init__(self, library_lambda: _LibraryLambda):
"""Wrap the native Cpp object.
Args:
library_lambda (_LibraryLambda): object to wrap
Raises:
TypeError: if library_lambda is not of type _LibraryLambda
"""
if not isinstance(library_lambda, _LibraryLambda):
raise TypeError(
f"library_lambda must be of type _LibraryLambda, not {type(library_lambda)}"
)
super().__init__(library_lambda)

View File

@@ -0,0 +1,202 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""LibraryCompilerSupport.
Library compilation provides a way to compile an MLIR program into a library that can be later loaded
to execute the compiled code.
"""
import os
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
LibraryLambdaSupport as _LibraryLambdaSupport,
)
# pylint: enable=no-name-in-module,import-error
from .compilation_options import CompilationOptions
from .library_compilation_result import LibraryCompilationResult
from .public_arguments import PublicArguments
from .library_lambda import LibraryLambda
from .public_result import PublicResult
from .client_parameters import ClientParameters
from .wrapper import WrapperCpp
# Default output path for compiled libraries
DEFAULT_OUTPUT_PATH = os.path.abspath(
os.path.join(os.path.curdir, "concrete-compiler_output_lib")
)
class LibraryLambdaSupport(WrapperCpp):
"""Support class for library compilation and execution."""
def __init__(self, library_lambda_support: _LibraryLambdaSupport):
"""Wrap the native Cpp object.
Args:
library_lambda_support (_LibraryLambdaSupport): object to wrap
Raises:
TypeError: if library_lambda_support is not of type _LibraryLambdaSupport
"""
if not isinstance(library_lambda_support, _LibraryLambdaSupport):
raise TypeError(
f"library_lambda_support must be of type _LibraryLambdaSupport, not "
f"{type(library_lambda_support)}"
)
super().__init__(library_lambda_support)
self.library_path = DEFAULT_OUTPUT_PATH
@property
def library_path(self) -> str:
"""Path where to store compiled libraries."""
return self._library_path
@library_path.setter
def library_path(self, path: str):
if not isinstance(path, str):
raise TypeError(f"path must be of type str, not {type(path)}")
self._library_path = path
@staticmethod
# pylint: disable=arguments-differ
def new(output_path: str = DEFAULT_OUTPUT_PATH) -> "LibraryLambdaSupport":
"""Build a LibraryLambdaSupport.
Args:
output_path (str, optional): path where to store compiled libraries.
Defaults to DEFAULT_OUTPUT_PATH.
Raises:
TypeError: if output_path is not of type str
Returns:
LibraryLambdaSupport
"""
if not isinstance(output_path, str):
raise TypeError(f"output_path must be of type str, not {type(output_path)}")
library_lambda_support = LibraryLambdaSupport.wrap(
_LibraryLambdaSupport(output_path)
)
library_lambda_support.library_path = output_path
return library_lambda_support
def compile(
self,
mlir_program: str,
options: CompilationOptions = CompilationOptions.new("main"),
) -> LibraryCompilationResult:
"""Compile an MLIR program using Concrete dialects into a library.
Args:
mlir_program (str): textual representation of the mlir program to compile
options (CompilationOptions): compilation options
Raises:
TypeError: if mlir_program is not of type str
TypeError: if options is not of type CompilationOptions
Returns:
LibraryCompilationResult: the result of the library compilation
"""
if not isinstance(mlir_program, str):
raise TypeError(
f"mlir_program must be of type str, not {type(mlir_program)}"
)
if not isinstance(options, CompilationOptions):
raise TypeError(
f"options must be of type CompilationOptions, not {type(options)}"
)
return LibraryCompilationResult.wrap(
self.cpp().compile(mlir_program, options.cpp())
)
def reload(self, func_name: str = "main") -> LibraryCompilationResult:
"""Reload the library compilation result from the library_path.
Args:
func_name: entrypoint function name
Returns:
LibraryCompilationResult: loaded library
"""
if not isinstance(func_name, str):
raise TypeError(f"func_name must be of type str, not {type(func_name)}")
return LibraryCompilationResult.new(self.library_path, func_name)
def load_client_parameters(
self, library_compilation_result: LibraryCompilationResult
) -> ClientParameters:
"""Load the client parameters from the library compilation result.
Args:
library_compilation_result (LibraryCompilationResult): compilation result of the library
Raises:
TypeError: if library_compilation_result is not of type LibraryCompilationResult
Returns:
ClientParameters: appropriate client parameters for the compiled library
"""
if not isinstance(library_compilation_result, LibraryCompilationResult):
raise TypeError(
f"library_compilation_result must be of type LibraryCompilationResult, not "
f"{type(library_compilation_result)}"
)
return ClientParameters.wrap(
self.cpp().load_client_parameters(library_compilation_result.cpp())
)
def load_server_lambda(
self, library_compilation_result: LibraryCompilationResult
) -> LibraryLambda:
"""Load the server lambda from the library compilation result.
Args:
library_compilation_result (LibraryCompilationResult): compilation result of the library
Raises:
TypeError: if library_compilation_result is not of type LibraryCompilationResult
Returns:
LibraryLambda: executable reference to the library
"""
if not isinstance(library_compilation_result, LibraryCompilationResult):
raise TypeError(
f"library_compilation_result must be of type LibraryCompilationResult, not "
f"{type(library_compilation_result)}"
)
return LibraryLambda.wrap(
self.cpp().load_server_lambda(library_compilation_result.cpp())
)
def server_call(
self, library_lambda: LibraryLambda, public_arguments: PublicArguments
) -> PublicResult:
"""Call the library with public_arguments.
Args:
library_lambda (LibraryLambda): reference to the compiled library
public_arguments (PublicArguments): arguments to use for execution
Raises:
TypeError: if library_lambda is not of type LibraryLambda
TypeError: if public_arguments is not of type PublicArguments
Returns:
PublicResult: result of the execution
"""
if not isinstance(library_lambda, LibraryLambda):
raise TypeError(
f"library_lambda must be of type LibraryLambda, not {type(library_lambda)}"
)
if not isinstance(public_arguments, PublicArguments):
raise TypeError(
f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
)
return PublicResult.wrap(
self.cpp().server_call(library_lambda.cpp(), public_arguments.cpp())
)

View File

@@ -0,0 +1,35 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""PublicArguments."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
PublicArguments as _PublicArguments,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class PublicArguments(WrapperCpp):
"""PublicArguments holds encrypted and plain arguments, as well as public materials.
An encrypted computation may require both encrypted and plain arguments, PublicArguments holds both
types, but also other public materials, such as public keys, which are required for private computation.
"""
def __init__(self, public_arguments: _PublicArguments):
"""Wrap the native Cpp object.
Args:
public_arguments (_PublicArguments): object to wrap
Raises:
TypeError: if public_arguments is not of type _PublicArguments
"""
if not isinstance(public_arguments, _PublicArguments):
raise TypeError(
f"public_arguments must be of type _PublicArguments, not {type(public_arguments)}"
)
super().__init__(public_arguments)

View File

@@ -0,0 +1,31 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""PublicResult."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
PublicResult as _PublicResult,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class PublicResult(WrapperCpp):
"""PublicResult holds the result of an encrypted execution and can be decrypted using ClientSupport."""
def __init__(self, public_result: _PublicResult):
"""Wrap the native Cpp object.
Args:
public_result (_PublicResult): object to wrap
Raises:
TypeError: if public_result is not of type _PublicResult
"""
if not isinstance(public_result, _PublicResult):
raise TypeError(
f"public_result must be of type _PublicResult, not {type(public_result)}"
)
super().__init__(public_result)

View File

@@ -0,0 +1,35 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""Common utils for the compiler submodule."""
import os
import numpy as np
ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64)
ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS
ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS
def lookup_runtime_lib() -> str:
"""Try to find the absolute path to the runtime library.
Returns:
str: absolute path to the runtime library, or empty str if unsuccessful.
"""
# Go up to site-packages level
cwd = os.path.abspath(__file__)
cwd = os.path.abspath(os.path.join(cwd, os.pardir))
cwd = os.path.abspath(os.path.join(cwd, os.pardir))
package_name = "concrete_compiler"
libs_path = os.path.join(cwd, f"{package_name}.libs")
# Can be because it's not a properly installed package
if not os.path.exists(libs_path):
return ""
runtime_library_paths = [
filename
for filename in os.listdir(libs_path)
if filename.startswith("libConcretelangRuntime")
]
assert len(runtime_library_paths) == 1, "should be one and only one runtime library"
return os.path.join(libs_path, runtime_library_paths[0])

View File

@@ -0,0 +1,39 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""Wrapper for native Cpp objects."""
class WrapperCpp:
"""Wrapper base class for native Cpp objects.
Initialization should mainly store the wrapped object, and future calls to the wrapper will be forwarded
to it. A static wrap method is provided to be more explicit. Wrappers should always be constructed using
the new method, which construct the Cpp object using the provided arguments, then wrap it. Classes that
inherit from this class should preferably type check the wrapped object during calls to init, and
reimplement the new method if the class is meant to be constructed.
"""
def __init__(self, cpp_obj):
self._cpp_obj = cpp_obj
@classmethod
def wrap(cls, cpp_obj) -> "WrapperCpp":
"""Wrap the Cpp object into a Python object.
Args:
cpp_obj: object to wrap
Returns:
WrapperCpp: wrapper
"""
return cls(cpp_obj)
@staticmethod
def new(*args, **kwargs):
"""Create a new wrapper by building the underlying object with a specific set of arguments."""
raise RuntimeError("This class shouldn't be built")
def cpp(self):
"""Return the Cpp wrapped object."""
return self._cpp_obj