diff --git a/compiler/lib/Bindings/Python/CMakeLists.txt b/compiler/lib/Bindings/Python/CMakeLists.txt index 3d75a5089..57605d8bb 100644 --- a/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compiler/lib/Bindings/Python/CMakeLists.txt @@ -26,7 +26,23 @@ declare_mlir_python_extension(ConcretelangBindingsPythonExtension.Core declare_mlir_python_sources(ConcretelangBindingsPythonSources ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" SOURCES - concrete/compiler.py + concrete/compiler/__init__.py + concrete/compiler/client_parameters.py + concrete/compiler/client_support.py + concrete/compiler/compilation_options.py + concrete/compiler/jit_compilation_result.py + concrete/compiler/jit_lambda_support.py + concrete/compiler/jit_lambda.py + concrete/compiler/key_set_cache.py + concrete/compiler/key_set.py + concrete/compiler/lambda_argument.py + concrete/compiler/library_compilation_result.py + concrete/compiler/library_lambda_support.py + concrete/compiler/library_lambda.py + concrete/compiler/public_arguments.py + concrete/compiler/public_result.py + concrete/compiler/utils.py + concrete/compiler/wrapper.py concrete/__init__.py concrete/lang/__init__.py concrete/lang/dialects/__init__.py diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index d3af32d43..5593244b4 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -44,16 +44,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](CompilationOptions &options, bool b) { options.verifyDiagnostics = b; }) - .def("auto_parallelize", [](CompilationOptions &options, - bool b) { options.autoParallelize = b; }) - .def("loop_parallelize", [](CompilationOptions &options, - bool b) { options.loopParallelize = b; }) - .def("dataflow_parallelize", [](CompilationOptions &options, bool b) { + .def("set_auto_parallelize", [](CompilationOptions &options, + bool b) { options.autoParallelize = b; }) + .def("set_loop_parallelize", [](CompilationOptions &options, + bool b) { options.loopParallelize = b; }) + .def("set_dataflow_parallelize", [](CompilationOptions &options, bool b) { options.dataflowParallelize = b; }); pybind11::class_( - m, "JitCompilationResult"); + m, "JITCompilationResult"); pybind11::class_>(m, "JITLambda"); diff --git a/compiler/lib/Bindings/Python/concrete/compiler.py b/compiler/lib/Bindings/Python/concrete/compiler.py deleted file mode 100644 index fa38ce653..000000000 --- a/compiler/lib/Bindings/Python/concrete/compiler.py +++ /dev/null @@ -1,380 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. - -"""Compiler submodule""" -from collections.abc import Iterable -import os -import atexit -from typing import List, Union - -from mlir._mlir_libs._concretelang._compiler import ( - terminate_parallelization as _terminate_parallelization, -) - -from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip - -from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport - -from mlir._mlir_libs._concretelang._compiler import ClientParameters - -from mlir._mlir_libs._concretelang._compiler import KeySet -from mlir._mlir_libs._concretelang._compiler import KeySetCache - -from mlir._mlir_libs._concretelang._compiler import PublicResult -from mlir._mlir_libs._concretelang._compiler import PublicArguments -from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument - -from mlir._mlir_libs._concretelang._compiler import CompilationOptions - -from mlir._mlir_libs._concretelang._compiler import ( - JITLambdaSupport as _JITLambdaSupport, -) -from mlir._mlir_libs._concretelang._compiler import JitCompilationResult -from mlir._mlir_libs._concretelang._compiler import JITLambda - -from mlir._mlir_libs._concretelang._compiler import ( - LibraryLambdaSupport as _LibraryLambdaSupport, -) -from mlir._mlir_libs._concretelang._compiler import LibraryCompilationResult -from mlir._mlir_libs._concretelang._compiler import LibraryLambda -import numpy as np - - -ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64) -ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS -ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS - - -# Terminate parallelization in the compiler (if init) during cleanup -atexit.register(_terminate_parallelization) - - -def _lookup_runtime_lib() -> str: - """Try to find the absolute path to the runtime library. - - Returns: - str: absolute path to the runtime library, or empty str if unsuccessful. - """ - # Go up to site-packages level - cwd = os.path.abspath(__file__) - cwd = os.path.abspath(os.path.join(cwd, os.pardir)) - cwd = os.path.abspath(os.path.join(cwd, os.pardir)) - package_name = "concrete_compiler" - libs_path = os.path.join(cwd, f"{package_name}.libs") - # Can be because it's not a properly installed package - if not os.path.exists(libs_path): - return "" - runtime_library_paths = [ - filename - for filename in os.listdir(libs_path) - if filename.startswith("libConcretelangRuntime") - ] - assert len(runtime_library_paths) == 1, "should be one and only one runtime library" - return os.path.join(libs_path, runtime_library_paths[0]) - - -def round_trip(mlir_str: str) -> str: - """Parse the MLIR input, then return it back. - - Args: - mlir_str (str): MLIR code to parse. - - Raises: - TypeError: if the argument is not an str. - - Returns: - str: parsed MLIR input. - """ - if not isinstance(mlir_str, str): - raise TypeError("input must be an `str`") - return _round_trip(mlir_str) - - -class CompilerEngine: - def __init__(self, mlir_str: str = None): - self._engine = JITCompilerSupport() - self._lambda = None - if mlir_str is not None: - self.compile_fhe(mlir_str) - - def compile_fhe( - self, - mlir_str: str, - func_name: str = "main", - unsecure_key_set_cache_path: str = None, - auto_parallelize: bool = False, - loop_parallelize: bool = False, - df_parallelize: bool = False, - ): - """Compile the MLIR input. - - Args: - mlir_str (str): MLIR to compile. - func_name (str): name of the function to set as entrypoint (default: main). - unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None). - auto_parallelize (bool): whether to activate auto-parallelization or not (default: False), - loop_parallelize (bool): whether to activate loop-parallelization or not (default: False), - df_parallelize (bool): whether to activate dataflow-parallelization or not (default: False), - - Raises: - TypeError: if the argument is not an str. - """ - if not all( - isinstance(flag, bool) - for flag in [auto_parallelize, loop_parallelize, df_parallelize] - ): - raise TypeError( - "parallelization flags (auto_parallelize, loop_parallelize, df_parallelize), should be booleans" - ) - unsecure_key_set_cache_path = unsecure_key_set_cache_path or "" - if not isinstance(unsecure_key_set_cache_path, str): - raise TypeError("unsecure_key_set_cache_path must be a str") - options = CompilationOptions(func_name) - options.auto_parallelize(auto_parallelize) - options.loop_parallelize(loop_parallelize) - options.dataflow_parallelize(df_parallelize) - self._compilation_result = self._engine.compile(mlir_str, options) - self._client_parameters = self._engine.load_client_parameters( - self._compilation_result - ) - keyset_cache = None - if not unsecure_key_set_cache_path is None: - keyset_cache = KeySetCache(unsecure_key_set_cache_path) - self._key_set = ClientSupport.key_set(self._client_parameters, keyset_cache) - - def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]: - """Run the compiled code. - - Args: - *args: list of arguments for execution. Each argument can be an int, or a numpy.array - - Raises: - TypeError: if execution arguments can't be constructed - RuntimeError: if the engine has not compiled any code yet - RuntimeError: if the return type is unknown - - Returns: - int or numpy.array: result of execution. - """ - if self._compilation_result is None: - raise RuntimeError("need to compile an MLIR code first") - # Client - public_arguments = ClientSupport.encrypt_arguments( - self._client_parameters, self._key_set, args - ) - # Server - server_lambda = self._engine.load_server_lambda(self._compilation_result) - public_result = self._engine.server_call(server_lambda, public_arguments) - # Client - return ClientSupport.decrypt_result(self._key_set, public_result) - - -class ClientSupport: - def key_set( - client_parameters: ClientParameters, cache: KeySetCache = None - ) -> KeySet: - """Generates a key set according to the given client parameters. - If the cache is set the key set is loaded from it if exists, else the new generated key set is saved in the cache - - Args: - client_parameters: A client parameters specification - cache: An optional cache of key set. - - Returns: - KeySet: the key set - """ - return _ClientSupport.key_set(client_parameters, cache) - - def encrypt_arguments( - client_parameters: ClientParameters, - key_set: KeySet, - args: List[Union[int, np.ndarray]], - ) -> PublicArguments: - """Export clear arguments to public arguments. - For each arguments this method encrypts the argument if it's declared as encrypted and pack to the public arguments object. - - Args: - client_parameters: A client parameters specification - key_set: A key set used to encrypt encrypted arguments - - Returns: - PublicArguments: the public arguments - """ - execution_arguments = [ - ClientSupport._create_execution_argument(arg) for arg in args - ] - return _ClientSupport.encrypt_arguments( - client_parameters, key_set, execution_arguments - ) - - def decrypt_result( - key_set: KeySet, public_result: PublicResult - ) -> Union[int, np.ndarray]: - """Decrypt a public result thanks the given key set. - - Args: - key_set: The key set used to decrypt the result. - public_result: The public result to descrypt. - - Returns: - int or numpy.array: The result of decryption. - """ - lambda_arg = _ClientSupport.decrypt_result(key_set, public_result) - if lambda_arg.is_scalar(): - return lambda_arg.get_scalar() - elif lambda_arg.is_tensor(): - shape = lambda_arg.get_tensor_shape() - tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape) - return tensor - else: - raise RuntimeError("unknown return type") - - def _create_execution_argument(value: Union[int, np.ndarray]) -> _LambdaArgument: - """Create an execution argument holding either an int or tensor value. - - Args: - value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array - - Raises: - TypeError: if the values aren't in the expected range, or using a wrong type - - Returns: - _LambdaArgument: lambda argument holding the appropriate value - """ - if not isinstance(value, ACCEPTED_TYPES): - raise TypeError( - "value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}" - ) - if isinstance(value, ACCEPTED_INTS): - if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max): - raise TypeError( - "single integer must be in the range [0, 2**64 - 1] (uint64)" - ) - return _LambdaArgument.from_scalar(value) - else: - assert isinstance(value, np.ndarray) - if value.shape == (): - return _LambdaArgument.from_scalar(value) - if value.dtype not in ACCEPTED_NUMPY_UINTS: - raise TypeError("numpy.array must be of dtype uint{8,16,32,64}") - return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape) - - -class JITCompilerSupport: - def __init__(self, runtime_lib_path=None): - if runtime_lib_path is None: - runtime_lib_path = _lookup_runtime_lib() - else: - if not isinstance(runtime_lib_path, str): - raise TypeError( - "runtime_lib_path must be an str representing the path to the runtime lib" - ) - self._support = _JITLambdaSupport(runtime_lib_path) - - def compile( - self, - mlir_program: str, - options: CompilationOptions = CompilationOptions("main"), - ) -> JitCompilationResult: - """JIT Compile a function define in the mlir_program to its homomorphic equivalent. - - Args: - mlir_program: A textual representation of the mlir program to compile. - func_name: The name of the function to compile. - - Returns: - JITCompilationResult: the result of the JIT compilation. - """ - if not isinstance(mlir_program, str): - raise TypeError("mlir_program must be an `str`") - return self._support.compile(mlir_program, options) - - def load_client_parameters( - self, compilation_result: JitCompilationResult - ) -> ClientParameters: - """Load the client parameters from the JIT compilation result""" - return self._support.load_client_parameters(compilation_result) - - def load_server_lambda(self, compilation_result: JitCompilationResult) -> JITLambda: - """Load the server lambda from the JIT compilation result""" - return self._support.load_server_lambda(compilation_result) - - def server_call(self, server_lambda: JITLambda, public_arguments: PublicArguments): - """Call the server lambda with public_arguments - - Args: - server_lambda: A server lambda to call - public_arguments: The arguments of the call - - Returns: - PublicResult: the result of the call of the server lambda - """ - return self._support.server_call(server_lambda, public_arguments) - - -class LibraryCompilerSupport: - def __init__(self, outputPath="./out"): - self._library_path = outputPath - self._support = _LibraryLambdaSupport(outputPath) - - def compile( - self, - mlir_program: str, - options: CompilationOptions = CompilationOptions("main"), - ) -> LibraryCompilationResult: - """Compile a function define in the mlir_program to its homomorphic equivalent and save as library. - - Args: - mlir_program: A textual representation of the mlir program to compile. - func_name: The name of the function to compile. - - Returns: - LibraryCompilationResult: the result of the compilation. - """ - if not isinstance(mlir_program, str): - raise TypeError("mlir_program must be an `str`") - if not isinstance(options, CompilationOptions): - raise TypeError("mlir_program must be an `str`") - return self._support.compile(mlir_program, options) - - def reload(self, func_name: str = "main") -> LibraryCompilationResult: - """Reload the library compilation result from the outputPath. - Args: - library-path: The path of the compiled library. - func_name: The name of the compiled function. - - Returns: - LibraryCompilationResult: the result of a compilation. - """ - if not isinstance(func_name, str): - raise TypeError("func_name must be an `str`") - return LibraryCompilationResult(self._library_path, func_name) - - def load_client_parameters( - self, compilation_result: LibraryCompilationResult - ) -> ClientParameters: - """Load the client parameters from the JIT compilation result""" - if not isinstance(compilation_result, LibraryCompilationResult): - raise TypeError("compilation_result must be an `LibraryCompilationResult`") - - return self._support.load_client_parameters(compilation_result) - - def load_server_lambda( - self, compilation_result: LibraryCompilationResult - ) -> LibraryLambda: - """Load the server lambda from the JIT compilation result""" - return self._support.load_server_lambda(compilation_result) - - def server_call( - self, server_lambda: LibraryLambda, public_arguments: PublicArguments - ) -> PublicResult: - """Call the server lambda with public_arguments - - Args: - server_lambda: A server lambda to call - public_arguments: The arguments of the call - - Returns: - PublicResult: the result of the call of the server lambda - """ - return self._support.server_call(server_lambda, public_arguments) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compiler/lib/Bindings/Python/concrete/compiler/__init__.py new file mode 100644 index 000000000..95b25c5c6 --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -0,0 +1,48 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""Compiler submodule.""" +import atexit + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + terminate_parallelization as _terminate_parallelization, +) +from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip + +# pylint: enable=no-name-in-module,import-error + +from .compilation_options import CompilationOptions +from .key_set_cache import KeySetCache +from .client_parameters import ClientParameters +from .key_set import KeySet +from .public_result import PublicResult +from .public_arguments import PublicArguments +from .jit_compilation_result import JITCompilationResult +from .jit_lambda import JITLambda +from .client_support import ClientSupport +from .jit_lambda_support import JITLambdaSupport +from .library_lambda_support import LibraryLambdaSupport + + +# Terminate parallelization in the compiler (if init) during cleanup +atexit.register(_terminate_parallelization) + + +def round_trip(mlir_str: str) -> str: + """Parse the MLIR input, then return it back. + + Useful to check the validity of an MLIR representation + + Args: + mlir_str (str): textual representation of an MLIR code + + Raises: + TypeError: if mlir_str is not of type str + + Returns: + str: textual representation of the MLIR code after parsing + """ + if not isinstance(mlir_str, str): + raise TypeError(f"mlir_str must be of type str, not {type(mlir_str)}") + return _round_trip(mlir_str) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py b/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py new file mode 100644 index 000000000..cffbaa97a --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py @@ -0,0 +1,36 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""Client parameters.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + ClientParameters as _ClientParameters, +) + +# pylint: enable=no-name-in-module,import-error + +from .wrapper import WrapperCpp + + +class ClientParameters(WrapperCpp): + """ClientParameters are public parameters used for key generation. + + It's a compilation artifact that describes which and how public and private keys should be generated, + and used to encrypt arguments of the compiled function. + """ + + def __init__(self, client_parameters: _ClientParameters): + """Wrap the native Cpp object. + + Args: + client_parameters (_ClientParameters): object to wrap + + Raises: + TypeError: if client_parameters is not of type _ClientParameters + """ + if not isinstance(client_parameters, _ClientParameters): + raise TypeError( + f"client_parameters must be of type _ClientParameters, not {type(client_parameters)}" + ) + super().__init__(client_parameters) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/client_support.py b/compiler/lib/Bindings/Python/concrete/compiler/client_support.py new file mode 100644 index 000000000..b4ad5c8f6 --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/client_support.py @@ -0,0 +1,191 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""Client support.""" +from typing import List, Optional, Union +import numpy as np + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport + +# pylint: enable=no-name-in-module,import-error + +from .public_result import PublicResult +from .key_set import KeySet +from .key_set_cache import KeySetCache +from .client_parameters import ClientParameters +from .public_arguments import PublicArguments +from .lambda_argument import LambdaArgument +from .wrapper import WrapperCpp +from .utils import ACCEPTED_INTS, ACCEPTED_NUMPY_UINTS, ACCEPTED_TYPES + + +class ClientSupport(WrapperCpp): + """Client interface for doing key generation and encryption. + + It provides features that are needed on the client side: + - Generation of public and private keys required for the encrypted computation + - Encryption and preparation of public arguments, used later as input to the computation + - Decryption of public result returned after execution + """ + + def __init__(self, client_support: _ClientSupport): + """Wrap the native Cpp object. + + Args: + client_support (_ClientSupport): object to wrap + + Raises: + TypeError: if client_support is not of type _ClientSupport + """ + if not isinstance(client_support, _ClientSupport): + raise TypeError( + f"client_support must be of type _ClientSupport not {type(client_support)}" + ) + super().__init__(client_support) + + @staticmethod + # pylint: disable=arguments-differ + def new() -> "ClientSupport": + """Build a ClientSupport. + + Returns: + ClientSupport + """ + return ClientSupport.wrap(_ClientSupport()) + + # pylint: enable=arguments-differ + + @staticmethod + def key_set( + client_parameters: ClientParameters, keyset_cache: Optional[KeySetCache] = None + ) -> KeySet: + """Generate a key set according to the client parameters. + + If the cache is set, and include equivalent keys as specified by the client parameters, the keyset + is loaded, otherwise, a new keyset is generated and saved in the cache. + + Args: + client_parameters (ClientParameters): client parameters specification + keyset_cache (Optional[KeySetCache], optional): keyset cache. Defaults to None. + + Raises: + TypeError: if client_parameters is not of type ClientParameters + TypeError: if keyset_cache is not of type KeySetCache + + Returns: + KeySet: generated or loaded keyset + """ + if keyset_cache is not None and not isinstance(keyset_cache, KeySetCache): + raise TypeError( + f"keyset_cache must be None or of type KeySetCache, not {type(keyset_cache)}" + ) + cpp_cache = None if keyset_cache is None else keyset_cache.cpp() + return KeySet.wrap(_ClientSupport.key_set(client_parameters.cpp(), cpp_cache)) + + @staticmethod + def encrypt_arguments( + client_parameters: ClientParameters, + keyset: KeySet, + args: List[Union[int, np.ndarray]], + ) -> PublicArguments: + """Prepare arguments for encrypted computation. + + Pack public arguments by encrypting the ones that requires encryption, and leaving the rest as plain. + It also pack public materials (public keys) that are required during the computation. + + Args: + client_parameters (ClientParameters): client parameters specification + keyset (KeySet): keyset used to encrypt arguments that require encryption + args (List[Union[int, np.ndarray]]): list of scalar or tensor arguments + + Raises: + TypeError: if client_parameters is not of type ClientParameters + TypeError: if keyset is not of type KeySet + + Returns: + PublicArguments: public arguments for execution + """ + if not isinstance(client_parameters, ClientParameters): + raise TypeError( + f"client_parameters must be of type ClientParameters, not {type(client_parameters)}" + ) + if not isinstance(keyset, KeySet): + raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}") + lambda_arguments = [ClientSupport._create_lambda_argument(arg) for arg in args] + return PublicArguments.wrap( + _ClientSupport.encrypt_arguments( + client_parameters.cpp(), + keyset.cpp(), + [arg.cpp() for arg in lambda_arguments], + ) + ) + + @staticmethod + def decrypt_result( + keyset: KeySet, public_result: PublicResult + ) -> Union[int, np.ndarray]: + """Decrypt a public result using the keyset. + + Args: + keyset (KeySet): keyset used for decryption + public_result: public result to decrypt + + Raises: + TypeError: if keyset is not of type KeySet + TypeError: if public_result is not of type PublicResult + RuntimeError: if the result is of an unknown type + + Returns: + Union[int, np.ndarray]: plain result + """ + if not isinstance(keyset, KeySet): + raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}") + if not isinstance(public_result, PublicResult): + raise TypeError( + f"public_result must be of type PublicResult, not {type(public_result)}" + ) + lambda_arg = LambdaArgument.wrap( + _ClientSupport.decrypt_result(keyset.cpp(), public_result.cpp()) + ) + if lambda_arg.is_scalar(): + return lambda_arg.get_scalar() + if lambda_arg.is_tensor(): + shape = lambda_arg.get_tensor_shape() + tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape) + return tensor + raise RuntimeError("unknown return type") + + @staticmethod + def _create_lambda_argument(value: Union[int, np.ndarray]) -> LambdaArgument: + """Create a lambda argument holding either an int or tensor value. + + Args: + value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array + + Raises: + TypeError: if the values aren't in the expected range, or using a wrong type + + Returns: + LambdaArgument: lambda argument holding the appropriate value + """ + if not isinstance(value, ACCEPTED_TYPES): + raise TypeError( + "value of lambda argument must be either int, numpy.array or numpy.uint{8,16,32,64}" + ) + if isinstance(value, ACCEPTED_INTS): + if isinstance(value, int) and not 0 <= value < np.iinfo(np.uint64).max: + raise TypeError( + "single integer must be in the range [0, 2**64 - 1] (uint64)" + ) + return LambdaArgument.from_scalar(value) + assert isinstance(value, np.ndarray) + if value.dtype not in ACCEPTED_NUMPY_UINTS: + raise TypeError("numpy.array must be of dtype uint{8,16,32,64}") + if value.shape == (): + if isinstance(value, np.ndarray): + # extract the single element + value = value.max() + # should be a single uint here + return LambdaArgument.from_scalar(value) + return LambdaArgument.from_tensor(value.flatten().tolist(), value.shape) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py b/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py new file mode 100644 index 000000000..338885b85 --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py @@ -0,0 +1,122 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""CompilationOptions.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + CompilationOptions as _CompilationOptions, +) +from .wrapper import WrapperCpp + +# pylint: enable=no-name-in-module,import-error + + +class CompilationOptions(WrapperCpp): + """CompilationOptions holds different flags and options of the compilation process. + + It controls different parallelization flags, diagnostic verification, and also the name of entrypoint + function. + """ + + def __init__(self, compilation_options: _CompilationOptions): + """Wrap the native Cpp object. + + Args: + compilation_options (_CompilationOptions): object to wrap + + Raises: + TypeError: if compilation_options is not of type _CompilationOptions + """ + if not isinstance(compilation_options, _CompilationOptions): + raise TypeError( + f"_compilation_options must be of type _CompilationOptions, not {type(compilation_options)}" + ) + super().__init__(compilation_options) + + @staticmethod + # pylint: disable=arguments-differ + def new(function_name="main") -> "CompilationOptions": + """Build a CompilationOptions. + + Args: + function_name (str, optional): name of the entrypoint function. Defaults to "main". + + Raises: + TypeError: if function_name is not an str + + Returns: + CompilationOptions + """ + if not isinstance(function_name, str): + raise TypeError( + f"function_name must be of type str not {type(function_name)}" + ) + return CompilationOptions.wrap(_CompilationOptions(function_name)) + + # pylint: enable=arguments-differ + + def set_auto_parallelize(self, auto_parallelize: bool): + """Set option for auto parallelization. + + Args: + auto_parallelize (bool): whether to turn it on or off + + Raises: + TypeError: if the value to set is not boolean + """ + if not isinstance(auto_parallelize, bool): + raise TypeError("can't set the option to a non-boolean value") + self.cpp().set_auto_parallelize(auto_parallelize) + + def set_loop_parallelize(self, loop_parallelize: bool): + """Set option for loop parallelization. + + Args: + loop_parallelize (bool): whether to turn it on or off + + Raises: + TypeError: if the value to set is not boolean + """ + if not isinstance(loop_parallelize, bool): + raise TypeError("can't set the option to a non-boolean value") + self.cpp().set_loop_parallelize(loop_parallelize) + + def set_verify_diagnostics(self, verify_diagnostics: bool): + """Set option for diagnostics verification. + + Args: + verify_diagnostics (bool): whether to turn it on or off + + Raises: + TypeError: if the value to set is not boolean + """ + if not isinstance(verify_diagnostics, bool): + raise TypeError("can't set the option to a non-boolean value") + self.cpp().set_verify_diagnostics(verify_diagnostics) + + def set_dataflow_parallelize(self, dataflow_parallelize: bool): + """Set option for dataflow parallelization. + + Args: + dataflow_parallelize (bool): whether to turn it on or off + + Raises: + TypeError: if the value to set is not boolean + """ + if not isinstance(dataflow_parallelize, bool): + raise TypeError("can't set the option to a non-boolean value") + self.cpp().set_dataflow_parallelize(dataflow_parallelize) + + def set_funcname(self, funcname: str): + """Set entrypoint function name. + + Args: + funcname (str): name of the entrypoint function + + Raises: + TypeError: if the value to set is not str + """ + if not isinstance(funcname, str): + raise TypeError("can't set the option to a non-str value") + self.cpp().set_funcname(funcname) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/jit_compilation_result.py b/compiler/lib/Bindings/Python/concrete/compiler/jit_compilation_result.py new file mode 100644 index 000000000..9be9757d3 --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/jit_compilation_result.py @@ -0,0 +1,38 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""JITCompilationResult.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + JITCompilationResult as _JITCompilationResult, +) + +# pylint: enable=no-name-in-module,import-error + + +from .wrapper import WrapperCpp + + +class JITCompilationResult(WrapperCpp): + """JITCompilationResult holds the result of a JIT compilation. + + It can be instrumented using the JITLambdaSupport to load client parameters and execute the compiled + code. + """ + + def __init__(self, jit_compilation_result: _JITCompilationResult): + """Wrap the native Cpp object. + + Args: + jit_compilation_result (_JITCompilationResult): object to wrap + + Raises: + TypeError: if jit_compilation_result is not of type _JITCompilationResult + """ + if not isinstance(jit_compilation_result, _JITCompilationResult): + raise TypeError( + f"jit_compilation_result must be of type _JITCompilationResult, not " + f"{type(jit_compilation_result)}" + ) + super().__init__(jit_compilation_result) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/jit_lambda.py b/compiler/lib/Bindings/Python/concrete/compiler/jit_lambda.py new file mode 100644 index 000000000..c1d6cf13d --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/jit_lambda.py @@ -0,0 +1,35 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""JITLambda.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + JITLambda as _JITLambda, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp + + +class JITLambda(WrapperCpp): + """JITLambda contains an in-memory executable code and can be ran using JITLambdaSupport. + + It's an artifact of JIT compilation, which stays in memory and can be executed with the help of + JITLambdaSupport. + """ + + def __init__(self, jit_lambda: _JITLambda): + """Wrap the native Cpp object. + + Args: + jit_lambda (_JITLambda): object to wrap + + Raises: + TypeError: if jit_lambda is not of type JITLambda + """ + if not isinstance(jit_lambda, _JITLambda): + raise TypeError( + f"jit_lambda must be of type _JITLambda, not {type(jit_lambda)}" + ) + super().__init__(jit_lambda) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/jit_lambda_support.py b/compiler/lib/Bindings/Python/concrete/compiler/jit_lambda_support.py new file mode 100644 index 000000000..f668bb0e0 --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/jit_lambda_support.py @@ -0,0 +1,167 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""JITLambdaSupport. + +Just-in-time compilation provide a way to compile and execute an MLIR program while keeping the executable +code in memory. +""" + +from typing import Optional + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + JITLambdaSupport as _JITLambdaSupport, +) + +# pylint: enable=no-name-in-module,import-error +from .utils import lookup_runtime_lib +from .compilation_options import CompilationOptions +from .jit_compilation_result import JITCompilationResult +from .client_parameters import ClientParameters +from .jit_lambda import JITLambda +from .public_arguments import PublicArguments +from .public_result import PublicResult +from .wrapper import WrapperCpp + + +class JITLambdaSupport(WrapperCpp): + """Support class for JIT compilation and execution.""" + + def __init__(self, jit_lambda_support: _JITLambdaSupport): + """Wrap the native Cpp object. + + Args: + jit_lambda_support (_JITLambdaSupport): object to wrap + + Raises: + TypeError: if jit_lambda_support is not of type _JITLambdaSupport + """ + if not isinstance(jit_lambda_support, _JITLambdaSupport): + raise TypeError( + f"jit_lambda_support must be of type _JITLambdaSupport not{type(jit_lambda_support)}" + ) + super().__init__(jit_lambda_support) + + @staticmethod + # pylint: disable=arguments-differ + def new(runtime_lib_path: Optional[str] = None) -> "JITLambdaSupport": + """Build a JITLambdaSupport. + + Args: + runtime_lib_path (Optional[str]): path to the runtime library. Defaults to None. + + Raises: + TypeError: if runtime_lib_path is not of type str or None + + Returns: + JITLambdaSupport + """ + if runtime_lib_path is None: + runtime_lib_path = lookup_runtime_lib() + else: + if not isinstance(runtime_lib_path, str): + raise TypeError( + f"runtime_lib_path must be of type str, not {type(runtime_lib_path)}" + ) + return JITLambdaSupport.wrap(_JITLambdaSupport(runtime_lib_path)) + + # pylint: enable=arguments-differ + + def compile( + self, + mlir_program: str, + options: CompilationOptions = CompilationOptions.new("main"), + ) -> JITCompilationResult: + """JIT compile an MLIR program using Concrete dialects. + + Args: + mlir_program (str): textual representation of the mlir program to compile + options (CompilationOptions): compilation options + + Raises: + TypeError: if mlir_program is not of type str + TypeError: if options is not of type CompilationOptions + + Returns: + JITCompilationResult: the result of the JIT compilation + """ + if not isinstance(mlir_program, str): + raise TypeError( + f"mlir_program must be of type str, not {type(mlir_program)}" + ) + if not isinstance(options, CompilationOptions): + raise TypeError( + f"options must be of type CompilationOptions, not {type(options)}" + ) + return JITCompilationResult.wrap( + self.cpp().compile(mlir_program, options.cpp()) + ) + + def load_client_parameters( + self, compilation_result: JITCompilationResult + ) -> ClientParameters: + """Load the client parameters from the JIT compilation result. + + Args: + compilation_result (JITCompilationResult): result of the JIT compilation + + Raises: + TypeError: if compilation_result is not of type JITCompilationResult + + Returns: + ClientParameters: appropriate client parameters for the compiled program + """ + if not isinstance(compilation_result, JITCompilationResult): + raise TypeError( + f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}" + ) + return ClientParameters.wrap( + self.cpp().load_client_parameters(compilation_result.cpp()) + ) + + def load_server_lambda(self, compilation_result: JITCompilationResult) -> JITLambda: + """Load the JITLambda from the JIT compilation result. + + Args: + compilation_result (JITCompilationResult): result of the JIT compilation. + + Raises: + TypeError: if compilation_result is not of type JITCompilationResult + + Returns: + JITLambda: loaded JITLambda to be executed + """ + if not isinstance(compilation_result, JITCompilationResult): + raise TypeError( + f"compilation_result must be a JITCompilationResult not {type(compilation_result)}" + ) + return JITLambda.wrap(self.cpp().load_server_lambda(compilation_result.cpp())) + + def server_call( + self, jit_lambda: JITLambda, public_arguments: PublicArguments + ) -> PublicResult: + """Call the JITLambda with public_arguments. + + Args: + jit_lambda (JITLambda): A server lambda to call. + public_arguments (PublicArguments): The arguments of the call. + + Raises: + TypeError: if jit_lambda is not of type JITLambda + TypeError: if public_arguments is not of type PublicArguments + + Returns: + PublicResult: the result of the call of the server lambda. + """ + if not isinstance(jit_lambda, JITLambda): + raise TypeError( + f"jit_lambda must be of type JITLambda, not {type(jit_lambda)}" + ) + if not isinstance(public_arguments, PublicArguments): + raise TypeError( + f"public_arguments must be of type PublicArguments, not {type(public_arguments)}" + ) + return PublicResult.wrap( + self.cpp().server_call(jit_lambda.cpp(), public_arguments.cpp()) + ) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/key_set.py b/compiler/lib/Bindings/Python/concrete/compiler/key_set.py new file mode 100644 index 000000000..815ef5f45 --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/key_set.py @@ -0,0 +1,36 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + + +"""KeySet. + +Store for the different keys required for an encrypted computation. +""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + KeySet as _KeySet, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp + + +class KeySet(WrapperCpp): + """KeySet stores the different keys required for an encrypted computation. + + Holds private keys (secret key) used for encryption/decryption, and public keys used for computation. + """ + + def __init__(self, keyset: _KeySet): + """Wrap the native Cpp object. + + Args: + keyset (_KeySet): object to wrap + + Raises: + TypeError: if keyset is not of type _KeySet + """ + if not isinstance(keyset, _KeySet): + raise TypeError(f"keyset must be of type _KeySet, not {type(keyset)}") + super().__init__(keyset) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/key_set_cache.py b/compiler/lib/Bindings/Python/concrete/compiler/key_set_cache.py new file mode 100644 index 000000000..4ceab69cd --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/key_set_cache.py @@ -0,0 +1,59 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""KeySetCache. + +Cache for keys to avoid generating similar keys multiple times. +""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + KeySetCache as _KeySetCache, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp + + +class KeySetCache(WrapperCpp): + """KeySetCache is a cache for KeySet to avoid generating similar keys multiple times. + + Keys get cached and can be later used instead of generating a new keyset which can take a lot of time. + """ + + def __init__(self, keyset_cache: _KeySetCache): + """Wrap the native Cpp object. + + Args: + keyset_cache (_KeySetCache): object to wrap + + Raises: + TypeError: if keyset_cache is not of type _KeySetCache + """ + if not isinstance(keyset_cache, _KeySetCache): + raise TypeError( + f"key_set_cache must be of type _KeySetCache, not {type(keyset_cache)}" + ) + super().__init__(keyset_cache) + + @staticmethod + # pylint: disable=arguments-differ + def new(cache_path: str) -> "KeySetCache": + """Build a KeySetCache located at cache_path. + + Args: + cache_path (str): path to the cache + + Raises: + TypeError: if the path is not of type str. + + Returns: + KeySetCache + """ + if not isinstance(cache_path, str): + raise TypeError( + f"cache_path must to be of type str, not {type(cache_path)}" + ) + return KeySetCache.wrap(_KeySetCache(cache_path)) + + # pylint: enable=arguments-differ diff --git a/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py b/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py new file mode 100644 index 000000000..0a2aebd2e --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py @@ -0,0 +1,116 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""LambdaArgument.""" +from typing import List + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + LambdaArgument as _LambdaArgument, +) + +# pylint: enable=no-name-in-module,import-error +from .utils import ACCEPTED_INTS +from .wrapper import WrapperCpp + + +class LambdaArgument(WrapperCpp): + """LambdaArgument holds scalar or tensor values.""" + + def __init__(self, lambda_argument: _LambdaArgument): + """Wrap the native Cpp object. + + Args: + lambda_argument (_LambdaArgument): object to wrap + + Raises: + TypeError: if lambda_argument is not of type _LambdaArgument + """ + if not isinstance(lambda_argument, _LambdaArgument): + raise TypeError( + f"lambda_argument must be of type _LambdaArgument, not {type(lambda_argument)}" + ) + super().__init__(lambda_argument) + + @staticmethod + def new(*args, **kwargs): + """Use from_scalar or from_tensor instead. + + Raises: + RuntimeError + """ + raise RuntimeError( + "you should call from_scalar or from_tensor according to the argument type" + ) + + @staticmethod + def from_scalar(scalar: int) -> "LambdaArgument": + """Build a LambdaArgument containing the given scalar value. + + Args: + scalar (int or numpy.uint): scalar value to embed in LambdaArgument + + Raises: + TypeError: if scalar is not of type int or numpy.uint + + Returns: + LambdaArgument + """ + if not isinstance(scalar, ACCEPTED_INTS): + raise TypeError( + f"scalar must be of type int or numpy.uint, not {type(scalar)}" + ) + return LambdaArgument.wrap(_LambdaArgument.from_scalar(scalar)) + + @staticmethod + def from_tensor(data: List[int], shape: List[int]) -> "LambdaArgument": + """Build a LambdaArgument containing the given tensor. + + Args: + data (List[int]): flattened tensor data + shape (List[int]): shape of original tensor before flattening + + Returns: + LambdaArgument + """ + return LambdaArgument.wrap(_LambdaArgument.from_tensor(data, shape)) + + def is_scalar(self) -> bool: + """Check if the contained argument is a scalar. + + Returns: + bool + """ + return self.cpp().is_scalar() + + def get_scalar(self) -> int: + """Return the contained scalar value. + + Returns: + int + """ + return self.cpp().get_scalar() + + def is_tensor(self) -> bool: + """Check if the contained argument is a tensor. + + Returns: + bool + """ + return self.cpp().is_tensor() + + def get_tensor_shape(self) -> List[int]: + """Return the shape of the contained tensor. + + Returns: + List[int]: tensor shape + """ + return self.cpp().get_tensor_shape() + + def get_tensor_data(self) -> List[int]: + """Return the contained flattened tensor data. + + Returns: + List[int] + """ + return self.cpp().get_tensor_data() diff --git a/compiler/lib/Bindings/Python/concrete/compiler/library_compilation_result.py b/compiler/lib/Bindings/Python/concrete/compiler/library_compilation_result.py new file mode 100644 index 000000000..78f2fb68f --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/library_compilation_result.py @@ -0,0 +1,60 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""LibraryCompilationResult.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + LibraryCompilationResult as _LibraryCompilationResult, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp + + +class LibraryCompilationResult(WrapperCpp): + """LibraryCompilationResult holds the result of the library compilation.""" + + def __init__(self, library_compilation_result: _LibraryCompilationResult): + """Wrap the native Cpp object. + + Args: + library_compilation_result (_LibraryCompilationResult): object to wrap + + Raises: + TypeError: if library_compilation_result is not of type _LibraryCompilationResult + """ + if not isinstance(library_compilation_result, _LibraryCompilationResult): + raise TypeError( + f"library_compilation_result must be of type _LibraryCompilationResult, not " + f"{type(library_compilation_result)}" + ) + super().__init__(library_compilation_result) + + @staticmethod + # pylint: disable=arguments-differ + def new(library_path: str, func_name: str) -> "LibraryCompilationResult": + """Build a LibraryCompilationResult at library_path, with func_name as entrypoint. + + Args: + library_path (str): path to the library + func_name (str): entrypoint function name + + Raises: + TypeError: if library_path is not of type str + TypeError: if func_name is not of type str + + Returns: + LibraryCompilationResult + """ + if not isinstance(library_path, str): + raise TypeError( + f"library_path must be of type str, not {type(library_path)}" + ) + if not isinstance(func_name, str): + raise TypeError(f"func_name must be of type str, not {type(func_name)}") + return LibraryCompilationResult.wrap( + _LibraryCompilationResult(library_path, func_name) + ) + + # pylint: enable=arguments-differ diff --git a/compiler/lib/Bindings/Python/concrete/compiler/library_lambda.py b/compiler/lib/Bindings/Python/concrete/compiler/library_lambda.py new file mode 100644 index 000000000..e34645083 --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/library_lambda.py @@ -0,0 +1,31 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""LibraryLambda.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + LibraryLambda as _LibraryLambda, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp + + +class LibraryLambda(WrapperCpp): + """LibraryLambda reference a compiled library and can be ran using LibraryLambdaSupport.""" + + def __init__(self, library_lambda: _LibraryLambda): + """Wrap the native Cpp object. + + Args: + library_lambda (_LibraryLambda): object to wrap + + Raises: + TypeError: if library_lambda is not of type _LibraryLambda + """ + if not isinstance(library_lambda, _LibraryLambda): + raise TypeError( + f"library_lambda must be of type _LibraryLambda, not {type(library_lambda)}" + ) + super().__init__(library_lambda) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/library_lambda_support.py b/compiler/lib/Bindings/Python/concrete/compiler/library_lambda_support.py new file mode 100644 index 000000000..f12338541 --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/library_lambda_support.py @@ -0,0 +1,202 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""LibraryCompilerSupport. + +Library compilation provides a way to compile an MLIR program into a library that can be later loaded +to execute the compiled code. +""" +import os + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + LibraryLambdaSupport as _LibraryLambdaSupport, +) + +# pylint: enable=no-name-in-module,import-error +from .compilation_options import CompilationOptions +from .library_compilation_result import LibraryCompilationResult +from .public_arguments import PublicArguments +from .library_lambda import LibraryLambda +from .public_result import PublicResult +from .client_parameters import ClientParameters +from .wrapper import WrapperCpp + + +# Default output path for compiled libraries +DEFAULT_OUTPUT_PATH = os.path.abspath( + os.path.join(os.path.curdir, "concrete-compiler_output_lib") +) + + +class LibraryLambdaSupport(WrapperCpp): + """Support class for library compilation and execution.""" + + def __init__(self, library_lambda_support: _LibraryLambdaSupport): + """Wrap the native Cpp object. + + Args: + library_lambda_support (_LibraryLambdaSupport): object to wrap + + Raises: + TypeError: if library_lambda_support is not of type _LibraryLambdaSupport + """ + if not isinstance(library_lambda_support, _LibraryLambdaSupport): + raise TypeError( + f"library_lambda_support must be of type _LibraryLambdaSupport, not " + f"{type(library_lambda_support)}" + ) + super().__init__(library_lambda_support) + self.library_path = DEFAULT_OUTPUT_PATH + + @property + def library_path(self) -> str: + """Path where to store compiled libraries.""" + return self._library_path + + @library_path.setter + def library_path(self, path: str): + if not isinstance(path, str): + raise TypeError(f"path must be of type str, not {type(path)}") + self._library_path = path + + @staticmethod + # pylint: disable=arguments-differ + def new(output_path: str = DEFAULT_OUTPUT_PATH) -> "LibraryLambdaSupport": + """Build a LibraryLambdaSupport. + + Args: + output_path (str, optional): path where to store compiled libraries. + Defaults to DEFAULT_OUTPUT_PATH. + + Raises: + TypeError: if output_path is not of type str + + Returns: + LibraryLambdaSupport + """ + if not isinstance(output_path, str): + raise TypeError(f"output_path must be of type str, not {type(output_path)}") + library_lambda_support = LibraryLambdaSupport.wrap( + _LibraryLambdaSupport(output_path) + ) + library_lambda_support.library_path = output_path + return library_lambda_support + + def compile( + self, + mlir_program: str, + options: CompilationOptions = CompilationOptions.new("main"), + ) -> LibraryCompilationResult: + """Compile an MLIR program using Concrete dialects into a library. + + Args: + mlir_program (str): textual representation of the mlir program to compile + options (CompilationOptions): compilation options + + Raises: + TypeError: if mlir_program is not of type str + TypeError: if options is not of type CompilationOptions + + Returns: + LibraryCompilationResult: the result of the library compilation + """ + if not isinstance(mlir_program, str): + raise TypeError( + f"mlir_program must be of type str, not {type(mlir_program)}" + ) + if not isinstance(options, CompilationOptions): + raise TypeError( + f"options must be of type CompilationOptions, not {type(options)}" + ) + return LibraryCompilationResult.wrap( + self.cpp().compile(mlir_program, options.cpp()) + ) + + def reload(self, func_name: str = "main") -> LibraryCompilationResult: + """Reload the library compilation result from the library_path. + + Args: + func_name: entrypoint function name + + Returns: + LibraryCompilationResult: loaded library + """ + if not isinstance(func_name, str): + raise TypeError(f"func_name must be of type str, not {type(func_name)}") + return LibraryCompilationResult.new(self.library_path, func_name) + + def load_client_parameters( + self, library_compilation_result: LibraryCompilationResult + ) -> ClientParameters: + """Load the client parameters from the library compilation result. + + Args: + library_compilation_result (LibraryCompilationResult): compilation result of the library + + Raises: + TypeError: if library_compilation_result is not of type LibraryCompilationResult + + Returns: + ClientParameters: appropriate client parameters for the compiled library + """ + if not isinstance(library_compilation_result, LibraryCompilationResult): + raise TypeError( + f"library_compilation_result must be of type LibraryCompilationResult, not " + f"{type(library_compilation_result)}" + ) + + return ClientParameters.wrap( + self.cpp().load_client_parameters(library_compilation_result.cpp()) + ) + + def load_server_lambda( + self, library_compilation_result: LibraryCompilationResult + ) -> LibraryLambda: + """Load the server lambda from the library compilation result. + + Args: + library_compilation_result (LibraryCompilationResult): compilation result of the library + + Raises: + TypeError: if library_compilation_result is not of type LibraryCompilationResult + + Returns: + LibraryLambda: executable reference to the library + """ + if not isinstance(library_compilation_result, LibraryCompilationResult): + raise TypeError( + f"library_compilation_result must be of type LibraryCompilationResult, not " + f"{type(library_compilation_result)}" + ) + return LibraryLambda.wrap( + self.cpp().load_server_lambda(library_compilation_result.cpp()) + ) + + def server_call( + self, library_lambda: LibraryLambda, public_arguments: PublicArguments + ) -> PublicResult: + """Call the library with public_arguments. + + Args: + library_lambda (LibraryLambda): reference to the compiled library + public_arguments (PublicArguments): arguments to use for execution + + Raises: + TypeError: if library_lambda is not of type LibraryLambda + TypeError: if public_arguments is not of type PublicArguments + + Returns: + PublicResult: result of the execution + """ + if not isinstance(library_lambda, LibraryLambda): + raise TypeError( + f"library_lambda must be of type LibraryLambda, not {type(library_lambda)}" + ) + if not isinstance(public_arguments, PublicArguments): + raise TypeError( + f"public_arguments must be of type PublicArguments, not {type(public_arguments)}" + ) + return PublicResult.wrap( + self.cpp().server_call(library_lambda.cpp(), public_arguments.cpp()) + ) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py b/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py new file mode 100644 index 000000000..8c4103286 --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py @@ -0,0 +1,35 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""PublicArguments.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + PublicArguments as _PublicArguments, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp + + +class PublicArguments(WrapperCpp): + """PublicArguments holds encrypted and plain arguments, as well as public materials. + + An encrypted computation may require both encrypted and plain arguments, PublicArguments holds both + types, but also other public materials, such as public keys, which are required for private computation. + """ + + def __init__(self, public_arguments: _PublicArguments): + """Wrap the native Cpp object. + + Args: + public_arguments (_PublicArguments): object to wrap + + Raises: + TypeError: if public_arguments is not of type _PublicArguments + """ + if not isinstance(public_arguments, _PublicArguments): + raise TypeError( + f"public_arguments must be of type _PublicArguments, not {type(public_arguments)}" + ) + super().__init__(public_arguments) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/public_result.py b/compiler/lib/Bindings/Python/concrete/compiler/public_result.py new file mode 100644 index 000000000..36997275b --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/public_result.py @@ -0,0 +1,31 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""PublicResult.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + PublicResult as _PublicResult, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp + + +class PublicResult(WrapperCpp): + """PublicResult holds the result of an encrypted execution and can be decrypted using ClientSupport.""" + + def __init__(self, public_result: _PublicResult): + """Wrap the native Cpp object. + + Args: + public_result (_PublicResult): object to wrap + + Raises: + TypeError: if public_result is not of type _PublicResult + """ + if not isinstance(public_result, _PublicResult): + raise TypeError( + f"public_result must be of type _PublicResult, not {type(public_result)}" + ) + super().__init__(public_result) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/utils.py b/compiler/lib/Bindings/Python/concrete/compiler/utils.py new file mode 100644 index 000000000..ac61f75bc --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/utils.py @@ -0,0 +1,35 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""Common utils for the compiler submodule.""" +import os +import numpy as np + + +ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64) +ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS +ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS + + +def lookup_runtime_lib() -> str: + """Try to find the absolute path to the runtime library. + + Returns: + str: absolute path to the runtime library, or empty str if unsuccessful. + """ + # Go up to site-packages level + cwd = os.path.abspath(__file__) + cwd = os.path.abspath(os.path.join(cwd, os.pardir)) + cwd = os.path.abspath(os.path.join(cwd, os.pardir)) + package_name = "concrete_compiler" + libs_path = os.path.join(cwd, f"{package_name}.libs") + # Can be because it's not a properly installed package + if not os.path.exists(libs_path): + return "" + runtime_library_paths = [ + filename + for filename in os.listdir(libs_path) + if filename.startswith("libConcretelangRuntime") + ] + assert len(runtime_library_paths) == 1, "should be one and only one runtime library" + return os.path.join(libs_path, runtime_library_paths[0]) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/wrapper.py b/compiler/lib/Bindings/Python/concrete/compiler/wrapper.py new file mode 100644 index 000000000..ce7dd473e --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/wrapper.py @@ -0,0 +1,39 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""Wrapper for native Cpp objects.""" + + +class WrapperCpp: + """Wrapper base class for native Cpp objects. + + Initialization should mainly store the wrapped object, and future calls to the wrapper will be forwarded + to it. A static wrap method is provided to be more explicit. Wrappers should always be constructed using + the new method, which construct the Cpp object using the provided arguments, then wrap it. Classes that + inherit from this class should preferably type check the wrapped object during calls to init, and + reimplement the new method if the class is meant to be constructed. + """ + + def __init__(self, cpp_obj): + self._cpp_obj = cpp_obj + + @classmethod + def wrap(cls, cpp_obj) -> "WrapperCpp": + """Wrap the Cpp object into a Python object. + + Args: + cpp_obj: object to wrap + + Returns: + WrapperCpp: wrapper + """ + return cls(cpp_obj) + + @staticmethod + def new(*args, **kwargs): + """Create a new wrapper by building the underlying object with a specific set of arguments.""" + raise RuntimeError("This class shouldn't be built") + + def cpp(self): + """Return the Cpp wrapped object.""" + return self._cpp_obj