mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: recreate the Python API using wrappers
All Cpp objects are now wrapped, and calls are being forwarded after strict type checking for avoiding weird behaviors. The same Cpp API is now exposed to Python
This commit is contained in:
@@ -26,7 +26,23 @@ declare_mlir_python_extension(ConcretelangBindingsPythonExtension.Core
|
||||
declare_mlir_python_sources(ConcretelangBindingsPythonSources
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
SOURCES
|
||||
concrete/compiler.py
|
||||
concrete/compiler/__init__.py
|
||||
concrete/compiler/client_parameters.py
|
||||
concrete/compiler/client_support.py
|
||||
concrete/compiler/compilation_options.py
|
||||
concrete/compiler/jit_compilation_result.py
|
||||
concrete/compiler/jit_lambda_support.py
|
||||
concrete/compiler/jit_lambda.py
|
||||
concrete/compiler/key_set_cache.py
|
||||
concrete/compiler/key_set.py
|
||||
concrete/compiler/lambda_argument.py
|
||||
concrete/compiler/library_compilation_result.py
|
||||
concrete/compiler/library_lambda_support.py
|
||||
concrete/compiler/library_lambda.py
|
||||
concrete/compiler/public_arguments.py
|
||||
concrete/compiler/public_result.py
|
||||
concrete/compiler/utils.py
|
||||
concrete/compiler/wrapper.py
|
||||
concrete/__init__.py
|
||||
concrete/lang/__init__.py
|
||||
concrete/lang/dialects/__init__.py
|
||||
|
||||
@@ -44,16 +44,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
[](CompilationOptions &options, bool b) {
|
||||
options.verifyDiagnostics = b;
|
||||
})
|
||||
.def("auto_parallelize", [](CompilationOptions &options,
|
||||
bool b) { options.autoParallelize = b; })
|
||||
.def("loop_parallelize", [](CompilationOptions &options,
|
||||
bool b) { options.loopParallelize = b; })
|
||||
.def("dataflow_parallelize", [](CompilationOptions &options, bool b) {
|
||||
.def("set_auto_parallelize", [](CompilationOptions &options,
|
||||
bool b) { options.autoParallelize = b; })
|
||||
.def("set_loop_parallelize", [](CompilationOptions &options,
|
||||
bool b) { options.loopParallelize = b; })
|
||||
.def("set_dataflow_parallelize", [](CompilationOptions &options, bool b) {
|
||||
options.dataflowParallelize = b;
|
||||
});
|
||||
|
||||
pybind11::class_<mlir::concretelang::JitCompilationResult>(
|
||||
m, "JitCompilationResult");
|
||||
m, "JITCompilationResult");
|
||||
pybind11::class_<mlir::concretelang::JITLambda,
|
||||
std::shared_ptr<mlir::concretelang::JITLambda>>(m,
|
||||
"JITLambda");
|
||||
|
||||
@@ -1,380 +0,0 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""Compiler submodule"""
|
||||
from collections.abc import Iterable
|
||||
import os
|
||||
import atexit
|
||||
from typing import List, Union
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
terminate_parallelization as _terminate_parallelization,
|
||||
)
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import ClientParameters
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import KeySet
|
||||
from mlir._mlir_libs._concretelang._compiler import KeySetCache
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import PublicResult
|
||||
from mlir._mlir_libs._concretelang._compiler import PublicArguments
|
||||
from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import CompilationOptions
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JITLambdaSupport as _JITLambdaSupport,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import JitCompilationResult
|
||||
from mlir._mlir_libs._concretelang._compiler import JITLambda
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
LibraryLambdaSupport as _LibraryLambdaSupport,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryCompilationResult
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryLambda
|
||||
import numpy as np
|
||||
|
||||
|
||||
ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64)
|
||||
ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS
|
||||
ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS
|
||||
|
||||
|
||||
# Terminate parallelization in the compiler (if init) during cleanup
|
||||
atexit.register(_terminate_parallelization)
|
||||
|
||||
|
||||
def _lookup_runtime_lib() -> str:
|
||||
"""Try to find the absolute path to the runtime library.
|
||||
|
||||
Returns:
|
||||
str: absolute path to the runtime library, or empty str if unsuccessful.
|
||||
"""
|
||||
# Go up to site-packages level
|
||||
cwd = os.path.abspath(__file__)
|
||||
cwd = os.path.abspath(os.path.join(cwd, os.pardir))
|
||||
cwd = os.path.abspath(os.path.join(cwd, os.pardir))
|
||||
package_name = "concrete_compiler"
|
||||
libs_path = os.path.join(cwd, f"{package_name}.libs")
|
||||
# Can be because it's not a properly installed package
|
||||
if not os.path.exists(libs_path):
|
||||
return ""
|
||||
runtime_library_paths = [
|
||||
filename
|
||||
for filename in os.listdir(libs_path)
|
||||
if filename.startswith("libConcretelangRuntime")
|
||||
]
|
||||
assert len(runtime_library_paths) == 1, "should be one and only one runtime library"
|
||||
return os.path.join(libs_path, runtime_library_paths[0])
|
||||
|
||||
|
||||
def round_trip(mlir_str: str) -> str:
|
||||
"""Parse the MLIR input, then return it back.
|
||||
|
||||
Args:
|
||||
mlir_str (str): MLIR code to parse.
|
||||
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
|
||||
Returns:
|
||||
str: parsed MLIR input.
|
||||
"""
|
||||
if not isinstance(mlir_str, str):
|
||||
raise TypeError("input must be an `str`")
|
||||
return _round_trip(mlir_str)
|
||||
|
||||
|
||||
class CompilerEngine:
|
||||
def __init__(self, mlir_str: str = None):
|
||||
self._engine = JITCompilerSupport()
|
||||
self._lambda = None
|
||||
if mlir_str is not None:
|
||||
self.compile_fhe(mlir_str)
|
||||
|
||||
def compile_fhe(
|
||||
self,
|
||||
mlir_str: str,
|
||||
func_name: str = "main",
|
||||
unsecure_key_set_cache_path: str = None,
|
||||
auto_parallelize: bool = False,
|
||||
loop_parallelize: bool = False,
|
||||
df_parallelize: bool = False,
|
||||
):
|
||||
"""Compile the MLIR input.
|
||||
|
||||
Args:
|
||||
mlir_str (str): MLIR to compile.
|
||||
func_name (str): name of the function to set as entrypoint (default: main).
|
||||
unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None).
|
||||
auto_parallelize (bool): whether to activate auto-parallelization or not (default: False),
|
||||
loop_parallelize (bool): whether to activate loop-parallelization or not (default: False),
|
||||
df_parallelize (bool): whether to activate dataflow-parallelization or not (default: False),
|
||||
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
"""
|
||||
if not all(
|
||||
isinstance(flag, bool)
|
||||
for flag in [auto_parallelize, loop_parallelize, df_parallelize]
|
||||
):
|
||||
raise TypeError(
|
||||
"parallelization flags (auto_parallelize, loop_parallelize, df_parallelize), should be booleans"
|
||||
)
|
||||
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
|
||||
if not isinstance(unsecure_key_set_cache_path, str):
|
||||
raise TypeError("unsecure_key_set_cache_path must be a str")
|
||||
options = CompilationOptions(func_name)
|
||||
options.auto_parallelize(auto_parallelize)
|
||||
options.loop_parallelize(loop_parallelize)
|
||||
options.dataflow_parallelize(df_parallelize)
|
||||
self._compilation_result = self._engine.compile(mlir_str, options)
|
||||
self._client_parameters = self._engine.load_client_parameters(
|
||||
self._compilation_result
|
||||
)
|
||||
keyset_cache = None
|
||||
if not unsecure_key_set_cache_path is None:
|
||||
keyset_cache = KeySetCache(unsecure_key_set_cache_path)
|
||||
self._key_set = ClientSupport.key_set(self._client_parameters, keyset_cache)
|
||||
|
||||
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
|
||||
"""Run the compiled code.
|
||||
|
||||
Args:
|
||||
*args: list of arguments for execution. Each argument can be an int, or a numpy.array
|
||||
|
||||
Raises:
|
||||
TypeError: if execution arguments can't be constructed
|
||||
RuntimeError: if the engine has not compiled any code yet
|
||||
RuntimeError: if the return type is unknown
|
||||
|
||||
Returns:
|
||||
int or numpy.array: result of execution.
|
||||
"""
|
||||
if self._compilation_result is None:
|
||||
raise RuntimeError("need to compile an MLIR code first")
|
||||
# Client
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
self._client_parameters, self._key_set, args
|
||||
)
|
||||
# Server
|
||||
server_lambda = self._engine.load_server_lambda(self._compilation_result)
|
||||
public_result = self._engine.server_call(server_lambda, public_arguments)
|
||||
# Client
|
||||
return ClientSupport.decrypt_result(self._key_set, public_result)
|
||||
|
||||
|
||||
class ClientSupport:
|
||||
def key_set(
|
||||
client_parameters: ClientParameters, cache: KeySetCache = None
|
||||
) -> KeySet:
|
||||
"""Generates a key set according to the given client parameters.
|
||||
If the cache is set the key set is loaded from it if exists, else the new generated key set is saved in the cache
|
||||
|
||||
Args:
|
||||
client_parameters: A client parameters specification
|
||||
cache: An optional cache of key set.
|
||||
|
||||
Returns:
|
||||
KeySet: the key set
|
||||
"""
|
||||
return _ClientSupport.key_set(client_parameters, cache)
|
||||
|
||||
def encrypt_arguments(
|
||||
client_parameters: ClientParameters,
|
||||
key_set: KeySet,
|
||||
args: List[Union[int, np.ndarray]],
|
||||
) -> PublicArguments:
|
||||
"""Export clear arguments to public arguments.
|
||||
For each arguments this method encrypts the argument if it's declared as encrypted and pack to the public arguments object.
|
||||
|
||||
Args:
|
||||
client_parameters: A client parameters specification
|
||||
key_set: A key set used to encrypt encrypted arguments
|
||||
|
||||
Returns:
|
||||
PublicArguments: the public arguments
|
||||
"""
|
||||
execution_arguments = [
|
||||
ClientSupport._create_execution_argument(arg) for arg in args
|
||||
]
|
||||
return _ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, execution_arguments
|
||||
)
|
||||
|
||||
def decrypt_result(
|
||||
key_set: KeySet, public_result: PublicResult
|
||||
) -> Union[int, np.ndarray]:
|
||||
"""Decrypt a public result thanks the given key set.
|
||||
|
||||
Args:
|
||||
key_set: The key set used to decrypt the result.
|
||||
public_result: The public result to descrypt.
|
||||
|
||||
Returns:
|
||||
int or numpy.array: The result of decryption.
|
||||
"""
|
||||
lambda_arg = _ClientSupport.decrypt_result(key_set, public_result)
|
||||
if lambda_arg.is_scalar():
|
||||
return lambda_arg.get_scalar()
|
||||
elif lambda_arg.is_tensor():
|
||||
shape = lambda_arg.get_tensor_shape()
|
||||
tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape)
|
||||
return tensor
|
||||
else:
|
||||
raise RuntimeError("unknown return type")
|
||||
|
||||
def _create_execution_argument(value: Union[int, np.ndarray]) -> _LambdaArgument:
|
||||
"""Create an execution argument holding either an int or tensor value.
|
||||
|
||||
Args:
|
||||
value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array
|
||||
|
||||
Raises:
|
||||
TypeError: if the values aren't in the expected range, or using a wrong type
|
||||
|
||||
Returns:
|
||||
_LambdaArgument: lambda argument holding the appropriate value
|
||||
"""
|
||||
if not isinstance(value, ACCEPTED_TYPES):
|
||||
raise TypeError(
|
||||
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
|
||||
)
|
||||
if isinstance(value, ACCEPTED_INTS):
|
||||
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
|
||||
raise TypeError(
|
||||
"single integer must be in the range [0, 2**64 - 1] (uint64)"
|
||||
)
|
||||
return _LambdaArgument.from_scalar(value)
|
||||
else:
|
||||
assert isinstance(value, np.ndarray)
|
||||
if value.shape == ():
|
||||
return _LambdaArgument.from_scalar(value)
|
||||
if value.dtype not in ACCEPTED_NUMPY_UINTS:
|
||||
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
|
||||
return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
|
||||
|
||||
|
||||
class JITCompilerSupport:
|
||||
def __init__(self, runtime_lib_path=None):
|
||||
if runtime_lib_path is None:
|
||||
runtime_lib_path = _lookup_runtime_lib()
|
||||
else:
|
||||
if not isinstance(runtime_lib_path, str):
|
||||
raise TypeError(
|
||||
"runtime_lib_path must be an str representing the path to the runtime lib"
|
||||
)
|
||||
self._support = _JITLambdaSupport(runtime_lib_path)
|
||||
|
||||
def compile(
|
||||
self,
|
||||
mlir_program: str,
|
||||
options: CompilationOptions = CompilationOptions("main"),
|
||||
) -> JitCompilationResult:
|
||||
"""JIT Compile a function define in the mlir_program to its homomorphic equivalent.
|
||||
|
||||
Args:
|
||||
mlir_program: A textual representation of the mlir program to compile.
|
||||
func_name: The name of the function to compile.
|
||||
|
||||
Returns:
|
||||
JITCompilationResult: the result of the JIT compilation.
|
||||
"""
|
||||
if not isinstance(mlir_program, str):
|
||||
raise TypeError("mlir_program must be an `str`")
|
||||
return self._support.compile(mlir_program, options)
|
||||
|
||||
def load_client_parameters(
|
||||
self, compilation_result: JitCompilationResult
|
||||
) -> ClientParameters:
|
||||
"""Load the client parameters from the JIT compilation result"""
|
||||
return self._support.load_client_parameters(compilation_result)
|
||||
|
||||
def load_server_lambda(self, compilation_result: JitCompilationResult) -> JITLambda:
|
||||
"""Load the server lambda from the JIT compilation result"""
|
||||
return self._support.load_server_lambda(compilation_result)
|
||||
|
||||
def server_call(self, server_lambda: JITLambda, public_arguments: PublicArguments):
|
||||
"""Call the server lambda with public_arguments
|
||||
|
||||
Args:
|
||||
server_lambda: A server lambda to call
|
||||
public_arguments: The arguments of the call
|
||||
|
||||
Returns:
|
||||
PublicResult: the result of the call of the server lambda
|
||||
"""
|
||||
return self._support.server_call(server_lambda, public_arguments)
|
||||
|
||||
|
||||
class LibraryCompilerSupport:
|
||||
def __init__(self, outputPath="./out"):
|
||||
self._library_path = outputPath
|
||||
self._support = _LibraryLambdaSupport(outputPath)
|
||||
|
||||
def compile(
|
||||
self,
|
||||
mlir_program: str,
|
||||
options: CompilationOptions = CompilationOptions("main"),
|
||||
) -> LibraryCompilationResult:
|
||||
"""Compile a function define in the mlir_program to its homomorphic equivalent and save as library.
|
||||
|
||||
Args:
|
||||
mlir_program: A textual representation of the mlir program to compile.
|
||||
func_name: The name of the function to compile.
|
||||
|
||||
Returns:
|
||||
LibraryCompilationResult: the result of the compilation.
|
||||
"""
|
||||
if not isinstance(mlir_program, str):
|
||||
raise TypeError("mlir_program must be an `str`")
|
||||
if not isinstance(options, CompilationOptions):
|
||||
raise TypeError("mlir_program must be an `str`")
|
||||
return self._support.compile(mlir_program, options)
|
||||
|
||||
def reload(self, func_name: str = "main") -> LibraryCompilationResult:
|
||||
"""Reload the library compilation result from the outputPath.
|
||||
Args:
|
||||
library-path: The path of the compiled library.
|
||||
func_name: The name of the compiled function.
|
||||
|
||||
Returns:
|
||||
LibraryCompilationResult: the result of a compilation.
|
||||
"""
|
||||
if not isinstance(func_name, str):
|
||||
raise TypeError("func_name must be an `str`")
|
||||
return LibraryCompilationResult(self._library_path, func_name)
|
||||
|
||||
def load_client_parameters(
|
||||
self, compilation_result: LibraryCompilationResult
|
||||
) -> ClientParameters:
|
||||
"""Load the client parameters from the JIT compilation result"""
|
||||
if not isinstance(compilation_result, LibraryCompilationResult):
|
||||
raise TypeError("compilation_result must be an `LibraryCompilationResult`")
|
||||
|
||||
return self._support.load_client_parameters(compilation_result)
|
||||
|
||||
def load_server_lambda(
|
||||
self, compilation_result: LibraryCompilationResult
|
||||
) -> LibraryLambda:
|
||||
"""Load the server lambda from the JIT compilation result"""
|
||||
return self._support.load_server_lambda(compilation_result)
|
||||
|
||||
def server_call(
|
||||
self, server_lambda: LibraryLambda, public_arguments: PublicArguments
|
||||
) -> PublicResult:
|
||||
"""Call the server lambda with public_arguments
|
||||
|
||||
Args:
|
||||
server_lambda: A server lambda to call
|
||||
public_arguments: The arguments of the call
|
||||
|
||||
Returns:
|
||||
PublicResult: the result of the call of the server lambda
|
||||
"""
|
||||
return self._support.server_call(server_lambda, public_arguments)
|
||||
48
compiler/lib/Bindings/Python/concrete/compiler/__init__.py
Normal file
48
compiler/lib/Bindings/Python/concrete/compiler/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""Compiler submodule."""
|
||||
import atexit
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
terminate_parallelization as _terminate_parallelization,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
from .compilation_options import CompilationOptions
|
||||
from .key_set_cache import KeySetCache
|
||||
from .client_parameters import ClientParameters
|
||||
from .key_set import KeySet
|
||||
from .public_result import PublicResult
|
||||
from .public_arguments import PublicArguments
|
||||
from .jit_compilation_result import JITCompilationResult
|
||||
from .jit_lambda import JITLambda
|
||||
from .client_support import ClientSupport
|
||||
from .jit_lambda_support import JITLambdaSupport
|
||||
from .library_lambda_support import LibraryLambdaSupport
|
||||
|
||||
|
||||
# Terminate parallelization in the compiler (if init) during cleanup
|
||||
atexit.register(_terminate_parallelization)
|
||||
|
||||
|
||||
def round_trip(mlir_str: str) -> str:
|
||||
"""Parse the MLIR input, then return it back.
|
||||
|
||||
Useful to check the validity of an MLIR representation
|
||||
|
||||
Args:
|
||||
mlir_str (str): textual representation of an MLIR code
|
||||
|
||||
Raises:
|
||||
TypeError: if mlir_str is not of type str
|
||||
|
||||
Returns:
|
||||
str: textual representation of the MLIR code after parsing
|
||||
"""
|
||||
if not isinstance(mlir_str, str):
|
||||
raise TypeError(f"mlir_str must be of type str, not {type(mlir_str)}")
|
||||
return _round_trip(mlir_str)
|
||||
@@ -0,0 +1,36 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""Client parameters."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
ClientParameters as _ClientParameters,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class ClientParameters(WrapperCpp):
|
||||
"""ClientParameters are public parameters used for key generation.
|
||||
|
||||
It's a compilation artifact that describes which and how public and private keys should be generated,
|
||||
and used to encrypt arguments of the compiled function.
|
||||
"""
|
||||
|
||||
def __init__(self, client_parameters: _ClientParameters):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
client_parameters (_ClientParameters): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if client_parameters is not of type _ClientParameters
|
||||
"""
|
||||
if not isinstance(client_parameters, _ClientParameters):
|
||||
raise TypeError(
|
||||
f"client_parameters must be of type _ClientParameters, not {type(client_parameters)}"
|
||||
)
|
||||
super().__init__(client_parameters)
|
||||
191
compiler/lib/Bindings/Python/concrete/compiler/client_support.py
Normal file
191
compiler/lib/Bindings/Python/concrete/compiler/client_support.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""Client support."""
|
||||
from typing import List, Optional, Union
|
||||
import numpy as np
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
from .public_result import PublicResult
|
||||
from .key_set import KeySet
|
||||
from .key_set_cache import KeySetCache
|
||||
from .client_parameters import ClientParameters
|
||||
from .public_arguments import PublicArguments
|
||||
from .lambda_argument import LambdaArgument
|
||||
from .wrapper import WrapperCpp
|
||||
from .utils import ACCEPTED_INTS, ACCEPTED_NUMPY_UINTS, ACCEPTED_TYPES
|
||||
|
||||
|
||||
class ClientSupport(WrapperCpp):
|
||||
"""Client interface for doing key generation and encryption.
|
||||
|
||||
It provides features that are needed on the client side:
|
||||
- Generation of public and private keys required for the encrypted computation
|
||||
- Encryption and preparation of public arguments, used later as input to the computation
|
||||
- Decryption of public result returned after execution
|
||||
"""
|
||||
|
||||
def __init__(self, client_support: _ClientSupport):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
client_support (_ClientSupport): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if client_support is not of type _ClientSupport
|
||||
"""
|
||||
if not isinstance(client_support, _ClientSupport):
|
||||
raise TypeError(
|
||||
f"client_support must be of type _ClientSupport not {type(client_support)}"
|
||||
)
|
||||
super().__init__(client_support)
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new() -> "ClientSupport":
|
||||
"""Build a ClientSupport.
|
||||
|
||||
Returns:
|
||||
ClientSupport
|
||||
"""
|
||||
return ClientSupport.wrap(_ClientSupport())
|
||||
|
||||
# pylint: enable=arguments-differ
|
||||
|
||||
@staticmethod
|
||||
def key_set(
|
||||
client_parameters: ClientParameters, keyset_cache: Optional[KeySetCache] = None
|
||||
) -> KeySet:
|
||||
"""Generate a key set according to the client parameters.
|
||||
|
||||
If the cache is set, and include equivalent keys as specified by the client parameters, the keyset
|
||||
is loaded, otherwise, a new keyset is generated and saved in the cache.
|
||||
|
||||
Args:
|
||||
client_parameters (ClientParameters): client parameters specification
|
||||
keyset_cache (Optional[KeySetCache], optional): keyset cache. Defaults to None.
|
||||
|
||||
Raises:
|
||||
TypeError: if client_parameters is not of type ClientParameters
|
||||
TypeError: if keyset_cache is not of type KeySetCache
|
||||
|
||||
Returns:
|
||||
KeySet: generated or loaded keyset
|
||||
"""
|
||||
if keyset_cache is not None and not isinstance(keyset_cache, KeySetCache):
|
||||
raise TypeError(
|
||||
f"keyset_cache must be None or of type KeySetCache, not {type(keyset_cache)}"
|
||||
)
|
||||
cpp_cache = None if keyset_cache is None else keyset_cache.cpp()
|
||||
return KeySet.wrap(_ClientSupport.key_set(client_parameters.cpp(), cpp_cache))
|
||||
|
||||
@staticmethod
|
||||
def encrypt_arguments(
|
||||
client_parameters: ClientParameters,
|
||||
keyset: KeySet,
|
||||
args: List[Union[int, np.ndarray]],
|
||||
) -> PublicArguments:
|
||||
"""Prepare arguments for encrypted computation.
|
||||
|
||||
Pack public arguments by encrypting the ones that requires encryption, and leaving the rest as plain.
|
||||
It also pack public materials (public keys) that are required during the computation.
|
||||
|
||||
Args:
|
||||
client_parameters (ClientParameters): client parameters specification
|
||||
keyset (KeySet): keyset used to encrypt arguments that require encryption
|
||||
args (List[Union[int, np.ndarray]]): list of scalar or tensor arguments
|
||||
|
||||
Raises:
|
||||
TypeError: if client_parameters is not of type ClientParameters
|
||||
TypeError: if keyset is not of type KeySet
|
||||
|
||||
Returns:
|
||||
PublicArguments: public arguments for execution
|
||||
"""
|
||||
if not isinstance(client_parameters, ClientParameters):
|
||||
raise TypeError(
|
||||
f"client_parameters must be of type ClientParameters, not {type(client_parameters)}"
|
||||
)
|
||||
if not isinstance(keyset, KeySet):
|
||||
raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}")
|
||||
lambda_arguments = [ClientSupport._create_lambda_argument(arg) for arg in args]
|
||||
return PublicArguments.wrap(
|
||||
_ClientSupport.encrypt_arguments(
|
||||
client_parameters.cpp(),
|
||||
keyset.cpp(),
|
||||
[arg.cpp() for arg in lambda_arguments],
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def decrypt_result(
|
||||
keyset: KeySet, public_result: PublicResult
|
||||
) -> Union[int, np.ndarray]:
|
||||
"""Decrypt a public result using the keyset.
|
||||
|
||||
Args:
|
||||
keyset (KeySet): keyset used for decryption
|
||||
public_result: public result to decrypt
|
||||
|
||||
Raises:
|
||||
TypeError: if keyset is not of type KeySet
|
||||
TypeError: if public_result is not of type PublicResult
|
||||
RuntimeError: if the result is of an unknown type
|
||||
|
||||
Returns:
|
||||
Union[int, np.ndarray]: plain result
|
||||
"""
|
||||
if not isinstance(keyset, KeySet):
|
||||
raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}")
|
||||
if not isinstance(public_result, PublicResult):
|
||||
raise TypeError(
|
||||
f"public_result must be of type PublicResult, not {type(public_result)}"
|
||||
)
|
||||
lambda_arg = LambdaArgument.wrap(
|
||||
_ClientSupport.decrypt_result(keyset.cpp(), public_result.cpp())
|
||||
)
|
||||
if lambda_arg.is_scalar():
|
||||
return lambda_arg.get_scalar()
|
||||
if lambda_arg.is_tensor():
|
||||
shape = lambda_arg.get_tensor_shape()
|
||||
tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape)
|
||||
return tensor
|
||||
raise RuntimeError("unknown return type")
|
||||
|
||||
@staticmethod
|
||||
def _create_lambda_argument(value: Union[int, np.ndarray]) -> LambdaArgument:
|
||||
"""Create a lambda argument holding either an int or tensor value.
|
||||
|
||||
Args:
|
||||
value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array
|
||||
|
||||
Raises:
|
||||
TypeError: if the values aren't in the expected range, or using a wrong type
|
||||
|
||||
Returns:
|
||||
LambdaArgument: lambda argument holding the appropriate value
|
||||
"""
|
||||
if not isinstance(value, ACCEPTED_TYPES):
|
||||
raise TypeError(
|
||||
"value of lambda argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
|
||||
)
|
||||
if isinstance(value, ACCEPTED_INTS):
|
||||
if isinstance(value, int) and not 0 <= value < np.iinfo(np.uint64).max:
|
||||
raise TypeError(
|
||||
"single integer must be in the range [0, 2**64 - 1] (uint64)"
|
||||
)
|
||||
return LambdaArgument.from_scalar(value)
|
||||
assert isinstance(value, np.ndarray)
|
||||
if value.dtype not in ACCEPTED_NUMPY_UINTS:
|
||||
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
|
||||
if value.shape == ():
|
||||
if isinstance(value, np.ndarray):
|
||||
# extract the single element
|
||||
value = value.max()
|
||||
# should be a single uint here
|
||||
return LambdaArgument.from_scalar(value)
|
||||
return LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
|
||||
@@ -0,0 +1,122 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""CompilationOptions."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
CompilationOptions as _CompilationOptions,
|
||||
)
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
|
||||
class CompilationOptions(WrapperCpp):
|
||||
"""CompilationOptions holds different flags and options of the compilation process.
|
||||
|
||||
It controls different parallelization flags, diagnostic verification, and also the name of entrypoint
|
||||
function.
|
||||
"""
|
||||
|
||||
def __init__(self, compilation_options: _CompilationOptions):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
compilation_options (_CompilationOptions): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_options is not of type _CompilationOptions
|
||||
"""
|
||||
if not isinstance(compilation_options, _CompilationOptions):
|
||||
raise TypeError(
|
||||
f"_compilation_options must be of type _CompilationOptions, not {type(compilation_options)}"
|
||||
)
|
||||
super().__init__(compilation_options)
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(function_name="main") -> "CompilationOptions":
|
||||
"""Build a CompilationOptions.
|
||||
|
||||
Args:
|
||||
function_name (str, optional): name of the entrypoint function. Defaults to "main".
|
||||
|
||||
Raises:
|
||||
TypeError: if function_name is not an str
|
||||
|
||||
Returns:
|
||||
CompilationOptions
|
||||
"""
|
||||
if not isinstance(function_name, str):
|
||||
raise TypeError(
|
||||
f"function_name must be of type str not {type(function_name)}"
|
||||
)
|
||||
return CompilationOptions.wrap(_CompilationOptions(function_name))
|
||||
|
||||
# pylint: enable=arguments-differ
|
||||
|
||||
def set_auto_parallelize(self, auto_parallelize: bool):
|
||||
"""Set option for auto parallelization.
|
||||
|
||||
Args:
|
||||
auto_parallelize (bool): whether to turn it on or off
|
||||
|
||||
Raises:
|
||||
TypeError: if the value to set is not boolean
|
||||
"""
|
||||
if not isinstance(auto_parallelize, bool):
|
||||
raise TypeError("can't set the option to a non-boolean value")
|
||||
self.cpp().set_auto_parallelize(auto_parallelize)
|
||||
|
||||
def set_loop_parallelize(self, loop_parallelize: bool):
|
||||
"""Set option for loop parallelization.
|
||||
|
||||
Args:
|
||||
loop_parallelize (bool): whether to turn it on or off
|
||||
|
||||
Raises:
|
||||
TypeError: if the value to set is not boolean
|
||||
"""
|
||||
if not isinstance(loop_parallelize, bool):
|
||||
raise TypeError("can't set the option to a non-boolean value")
|
||||
self.cpp().set_loop_parallelize(loop_parallelize)
|
||||
|
||||
def set_verify_diagnostics(self, verify_diagnostics: bool):
|
||||
"""Set option for diagnostics verification.
|
||||
|
||||
Args:
|
||||
verify_diagnostics (bool): whether to turn it on or off
|
||||
|
||||
Raises:
|
||||
TypeError: if the value to set is not boolean
|
||||
"""
|
||||
if not isinstance(verify_diagnostics, bool):
|
||||
raise TypeError("can't set the option to a non-boolean value")
|
||||
self.cpp().set_verify_diagnostics(verify_diagnostics)
|
||||
|
||||
def set_dataflow_parallelize(self, dataflow_parallelize: bool):
|
||||
"""Set option for dataflow parallelization.
|
||||
|
||||
Args:
|
||||
dataflow_parallelize (bool): whether to turn it on or off
|
||||
|
||||
Raises:
|
||||
TypeError: if the value to set is not boolean
|
||||
"""
|
||||
if not isinstance(dataflow_parallelize, bool):
|
||||
raise TypeError("can't set the option to a non-boolean value")
|
||||
self.cpp().set_dataflow_parallelize(dataflow_parallelize)
|
||||
|
||||
def set_funcname(self, funcname: str):
|
||||
"""Set entrypoint function name.
|
||||
|
||||
Args:
|
||||
funcname (str): name of the entrypoint function
|
||||
|
||||
Raises:
|
||||
TypeError: if the value to set is not str
|
||||
"""
|
||||
if not isinstance(funcname, str):
|
||||
raise TypeError("can't set the option to a non-str value")
|
||||
self.cpp().set_funcname(funcname)
|
||||
@@ -0,0 +1,38 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""JITCompilationResult."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JITCompilationResult as _JITCompilationResult,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class JITCompilationResult(WrapperCpp):
|
||||
"""JITCompilationResult holds the result of a JIT compilation.
|
||||
|
||||
It can be instrumented using the JITLambdaSupport to load client parameters and execute the compiled
|
||||
code.
|
||||
"""
|
||||
|
||||
def __init__(self, jit_compilation_result: _JITCompilationResult):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
jit_compilation_result (_JITCompilationResult): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if jit_compilation_result is not of type _JITCompilationResult
|
||||
"""
|
||||
if not isinstance(jit_compilation_result, _JITCompilationResult):
|
||||
raise TypeError(
|
||||
f"jit_compilation_result must be of type _JITCompilationResult, not "
|
||||
f"{type(jit_compilation_result)}"
|
||||
)
|
||||
super().__init__(jit_compilation_result)
|
||||
35
compiler/lib/Bindings/Python/concrete/compiler/jit_lambda.py
Normal file
35
compiler/lib/Bindings/Python/concrete/compiler/jit_lambda.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""JITLambda."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JITLambda as _JITLambda,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class JITLambda(WrapperCpp):
|
||||
"""JITLambda contains an in-memory executable code and can be ran using JITLambdaSupport.
|
||||
|
||||
It's an artifact of JIT compilation, which stays in memory and can be executed with the help of
|
||||
JITLambdaSupport.
|
||||
"""
|
||||
|
||||
def __init__(self, jit_lambda: _JITLambda):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
jit_lambda (_JITLambda): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if jit_lambda is not of type JITLambda
|
||||
"""
|
||||
if not isinstance(jit_lambda, _JITLambda):
|
||||
raise TypeError(
|
||||
f"jit_lambda must be of type _JITLambda, not {type(jit_lambda)}"
|
||||
)
|
||||
super().__init__(jit_lambda)
|
||||
@@ -0,0 +1,167 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""JITLambdaSupport.
|
||||
|
||||
Just-in-time compilation provide a way to compile and execute an MLIR program while keeping the executable
|
||||
code in memory.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JITLambdaSupport as _JITLambdaSupport,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .utils import lookup_runtime_lib
|
||||
from .compilation_options import CompilationOptions
|
||||
from .jit_compilation_result import JITCompilationResult
|
||||
from .client_parameters import ClientParameters
|
||||
from .jit_lambda import JITLambda
|
||||
from .public_arguments import PublicArguments
|
||||
from .public_result import PublicResult
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class JITLambdaSupport(WrapperCpp):
|
||||
"""Support class for JIT compilation and execution."""
|
||||
|
||||
def __init__(self, jit_lambda_support: _JITLambdaSupport):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
jit_lambda_support (_JITLambdaSupport): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if jit_lambda_support is not of type _JITLambdaSupport
|
||||
"""
|
||||
if not isinstance(jit_lambda_support, _JITLambdaSupport):
|
||||
raise TypeError(
|
||||
f"jit_lambda_support must be of type _JITLambdaSupport not{type(jit_lambda_support)}"
|
||||
)
|
||||
super().__init__(jit_lambda_support)
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(runtime_lib_path: Optional[str] = None) -> "JITLambdaSupport":
|
||||
"""Build a JITLambdaSupport.
|
||||
|
||||
Args:
|
||||
runtime_lib_path (Optional[str]): path to the runtime library. Defaults to None.
|
||||
|
||||
Raises:
|
||||
TypeError: if runtime_lib_path is not of type str or None
|
||||
|
||||
Returns:
|
||||
JITLambdaSupport
|
||||
"""
|
||||
if runtime_lib_path is None:
|
||||
runtime_lib_path = lookup_runtime_lib()
|
||||
else:
|
||||
if not isinstance(runtime_lib_path, str):
|
||||
raise TypeError(
|
||||
f"runtime_lib_path must be of type str, not {type(runtime_lib_path)}"
|
||||
)
|
||||
return JITLambdaSupport.wrap(_JITLambdaSupport(runtime_lib_path))
|
||||
|
||||
# pylint: enable=arguments-differ
|
||||
|
||||
def compile(
|
||||
self,
|
||||
mlir_program: str,
|
||||
options: CompilationOptions = CompilationOptions.new("main"),
|
||||
) -> JITCompilationResult:
|
||||
"""JIT compile an MLIR program using Concrete dialects.
|
||||
|
||||
Args:
|
||||
mlir_program (str): textual representation of the mlir program to compile
|
||||
options (CompilationOptions): compilation options
|
||||
|
||||
Raises:
|
||||
TypeError: if mlir_program is not of type str
|
||||
TypeError: if options is not of type CompilationOptions
|
||||
|
||||
Returns:
|
||||
JITCompilationResult: the result of the JIT compilation
|
||||
"""
|
||||
if not isinstance(mlir_program, str):
|
||||
raise TypeError(
|
||||
f"mlir_program must be of type str, not {type(mlir_program)}"
|
||||
)
|
||||
if not isinstance(options, CompilationOptions):
|
||||
raise TypeError(
|
||||
f"options must be of type CompilationOptions, not {type(options)}"
|
||||
)
|
||||
return JITCompilationResult.wrap(
|
||||
self.cpp().compile(mlir_program, options.cpp())
|
||||
)
|
||||
|
||||
def load_client_parameters(
|
||||
self, compilation_result: JITCompilationResult
|
||||
) -> ClientParameters:
|
||||
"""Load the client parameters from the JIT compilation result.
|
||||
|
||||
Args:
|
||||
compilation_result (JITCompilationResult): result of the JIT compilation
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_result is not of type JITCompilationResult
|
||||
|
||||
Returns:
|
||||
ClientParameters: appropriate client parameters for the compiled program
|
||||
"""
|
||||
if not isinstance(compilation_result, JITCompilationResult):
|
||||
raise TypeError(
|
||||
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
|
||||
)
|
||||
return ClientParameters.wrap(
|
||||
self.cpp().load_client_parameters(compilation_result.cpp())
|
||||
)
|
||||
|
||||
def load_server_lambda(self, compilation_result: JITCompilationResult) -> JITLambda:
|
||||
"""Load the JITLambda from the JIT compilation result.
|
||||
|
||||
Args:
|
||||
compilation_result (JITCompilationResult): result of the JIT compilation.
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_result is not of type JITCompilationResult
|
||||
|
||||
Returns:
|
||||
JITLambda: loaded JITLambda to be executed
|
||||
"""
|
||||
if not isinstance(compilation_result, JITCompilationResult):
|
||||
raise TypeError(
|
||||
f"compilation_result must be a JITCompilationResult not {type(compilation_result)}"
|
||||
)
|
||||
return JITLambda.wrap(self.cpp().load_server_lambda(compilation_result.cpp()))
|
||||
|
||||
def server_call(
|
||||
self, jit_lambda: JITLambda, public_arguments: PublicArguments
|
||||
) -> PublicResult:
|
||||
"""Call the JITLambda with public_arguments.
|
||||
|
||||
Args:
|
||||
jit_lambda (JITLambda): A server lambda to call.
|
||||
public_arguments (PublicArguments): The arguments of the call.
|
||||
|
||||
Raises:
|
||||
TypeError: if jit_lambda is not of type JITLambda
|
||||
TypeError: if public_arguments is not of type PublicArguments
|
||||
|
||||
Returns:
|
||||
PublicResult: the result of the call of the server lambda.
|
||||
"""
|
||||
if not isinstance(jit_lambda, JITLambda):
|
||||
raise TypeError(
|
||||
f"jit_lambda must be of type JITLambda, not {type(jit_lambda)}"
|
||||
)
|
||||
if not isinstance(public_arguments, PublicArguments):
|
||||
raise TypeError(
|
||||
f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
|
||||
)
|
||||
return PublicResult.wrap(
|
||||
self.cpp().server_call(jit_lambda.cpp(), public_arguments.cpp())
|
||||
)
|
||||
36
compiler/lib/Bindings/Python/concrete/compiler/key_set.py
Normal file
36
compiler/lib/Bindings/Python/concrete/compiler/key_set.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
|
||||
"""KeySet.
|
||||
|
||||
Store for the different keys required for an encrypted computation.
|
||||
"""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
KeySet as _KeySet,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class KeySet(WrapperCpp):
|
||||
"""KeySet stores the different keys required for an encrypted computation.
|
||||
|
||||
Holds private keys (secret key) used for encryption/decryption, and public keys used for computation.
|
||||
"""
|
||||
|
||||
def __init__(self, keyset: _KeySet):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
keyset (_KeySet): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if keyset is not of type _KeySet
|
||||
"""
|
||||
if not isinstance(keyset, _KeySet):
|
||||
raise TypeError(f"keyset must be of type _KeySet, not {type(keyset)}")
|
||||
super().__init__(keyset)
|
||||
@@ -0,0 +1,59 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""KeySetCache.
|
||||
|
||||
Cache for keys to avoid generating similar keys multiple times.
|
||||
"""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
KeySetCache as _KeySetCache,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class KeySetCache(WrapperCpp):
|
||||
"""KeySetCache is a cache for KeySet to avoid generating similar keys multiple times.
|
||||
|
||||
Keys get cached and can be later used instead of generating a new keyset which can take a lot of time.
|
||||
"""
|
||||
|
||||
def __init__(self, keyset_cache: _KeySetCache):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
keyset_cache (_KeySetCache): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if keyset_cache is not of type _KeySetCache
|
||||
"""
|
||||
if not isinstance(keyset_cache, _KeySetCache):
|
||||
raise TypeError(
|
||||
f"key_set_cache must be of type _KeySetCache, not {type(keyset_cache)}"
|
||||
)
|
||||
super().__init__(keyset_cache)
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(cache_path: str) -> "KeySetCache":
|
||||
"""Build a KeySetCache located at cache_path.
|
||||
|
||||
Args:
|
||||
cache_path (str): path to the cache
|
||||
|
||||
Raises:
|
||||
TypeError: if the path is not of type str.
|
||||
|
||||
Returns:
|
||||
KeySetCache
|
||||
"""
|
||||
if not isinstance(cache_path, str):
|
||||
raise TypeError(
|
||||
f"cache_path must to be of type str, not {type(cache_path)}"
|
||||
)
|
||||
return KeySetCache.wrap(_KeySetCache(cache_path))
|
||||
|
||||
# pylint: enable=arguments-differ
|
||||
@@ -0,0 +1,116 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""LambdaArgument."""
|
||||
from typing import List
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
LambdaArgument as _LambdaArgument,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .utils import ACCEPTED_INTS
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class LambdaArgument(WrapperCpp):
|
||||
"""LambdaArgument holds scalar or tensor values."""
|
||||
|
||||
def __init__(self, lambda_argument: _LambdaArgument):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
lambda_argument (_LambdaArgument): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if lambda_argument is not of type _LambdaArgument
|
||||
"""
|
||||
if not isinstance(lambda_argument, _LambdaArgument):
|
||||
raise TypeError(
|
||||
f"lambda_argument must be of type _LambdaArgument, not {type(lambda_argument)}"
|
||||
)
|
||||
super().__init__(lambda_argument)
|
||||
|
||||
@staticmethod
|
||||
def new(*args, **kwargs):
|
||||
"""Use from_scalar or from_tensor instead.
|
||||
|
||||
Raises:
|
||||
RuntimeError
|
||||
"""
|
||||
raise RuntimeError(
|
||||
"you should call from_scalar or from_tensor according to the argument type"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_scalar(scalar: int) -> "LambdaArgument":
|
||||
"""Build a LambdaArgument containing the given scalar value.
|
||||
|
||||
Args:
|
||||
scalar (int or numpy.uint): scalar value to embed in LambdaArgument
|
||||
|
||||
Raises:
|
||||
TypeError: if scalar is not of type int or numpy.uint
|
||||
|
||||
Returns:
|
||||
LambdaArgument
|
||||
"""
|
||||
if not isinstance(scalar, ACCEPTED_INTS):
|
||||
raise TypeError(
|
||||
f"scalar must be of type int or numpy.uint, not {type(scalar)}"
|
||||
)
|
||||
return LambdaArgument.wrap(_LambdaArgument.from_scalar(scalar))
|
||||
|
||||
@staticmethod
|
||||
def from_tensor(data: List[int], shape: List[int]) -> "LambdaArgument":
|
||||
"""Build a LambdaArgument containing the given tensor.
|
||||
|
||||
Args:
|
||||
data (List[int]): flattened tensor data
|
||||
shape (List[int]): shape of original tensor before flattening
|
||||
|
||||
Returns:
|
||||
LambdaArgument
|
||||
"""
|
||||
return LambdaArgument.wrap(_LambdaArgument.from_tensor(data, shape))
|
||||
|
||||
def is_scalar(self) -> bool:
|
||||
"""Check if the contained argument is a scalar.
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
return self.cpp().is_scalar()
|
||||
|
||||
def get_scalar(self) -> int:
|
||||
"""Return the contained scalar value.
|
||||
|
||||
Returns:
|
||||
int
|
||||
"""
|
||||
return self.cpp().get_scalar()
|
||||
|
||||
def is_tensor(self) -> bool:
|
||||
"""Check if the contained argument is a tensor.
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
return self.cpp().is_tensor()
|
||||
|
||||
def get_tensor_shape(self) -> List[int]:
|
||||
"""Return the shape of the contained tensor.
|
||||
|
||||
Returns:
|
||||
List[int]: tensor shape
|
||||
"""
|
||||
return self.cpp().get_tensor_shape()
|
||||
|
||||
def get_tensor_data(self) -> List[int]:
|
||||
"""Return the contained flattened tensor data.
|
||||
|
||||
Returns:
|
||||
List[int]
|
||||
"""
|
||||
return self.cpp().get_tensor_data()
|
||||
@@ -0,0 +1,60 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""LibraryCompilationResult."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
LibraryCompilationResult as _LibraryCompilationResult,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class LibraryCompilationResult(WrapperCpp):
|
||||
"""LibraryCompilationResult holds the result of the library compilation."""
|
||||
|
||||
def __init__(self, library_compilation_result: _LibraryCompilationResult):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
library_compilation_result (_LibraryCompilationResult): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if library_compilation_result is not of type _LibraryCompilationResult
|
||||
"""
|
||||
if not isinstance(library_compilation_result, _LibraryCompilationResult):
|
||||
raise TypeError(
|
||||
f"library_compilation_result must be of type _LibraryCompilationResult, not "
|
||||
f"{type(library_compilation_result)}"
|
||||
)
|
||||
super().__init__(library_compilation_result)
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(library_path: str, func_name: str) -> "LibraryCompilationResult":
|
||||
"""Build a LibraryCompilationResult at library_path, with func_name as entrypoint.
|
||||
|
||||
Args:
|
||||
library_path (str): path to the library
|
||||
func_name (str): entrypoint function name
|
||||
|
||||
Raises:
|
||||
TypeError: if library_path is not of type str
|
||||
TypeError: if func_name is not of type str
|
||||
|
||||
Returns:
|
||||
LibraryCompilationResult
|
||||
"""
|
||||
if not isinstance(library_path, str):
|
||||
raise TypeError(
|
||||
f"library_path must be of type str, not {type(library_path)}"
|
||||
)
|
||||
if not isinstance(func_name, str):
|
||||
raise TypeError(f"func_name must be of type str, not {type(func_name)}")
|
||||
return LibraryCompilationResult.wrap(
|
||||
_LibraryCompilationResult(library_path, func_name)
|
||||
)
|
||||
|
||||
# pylint: enable=arguments-differ
|
||||
@@ -0,0 +1,31 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""LibraryLambda."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
LibraryLambda as _LibraryLambda,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class LibraryLambda(WrapperCpp):
|
||||
"""LibraryLambda reference a compiled library and can be ran using LibraryLambdaSupport."""
|
||||
|
||||
def __init__(self, library_lambda: _LibraryLambda):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
library_lambda (_LibraryLambda): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if library_lambda is not of type _LibraryLambda
|
||||
"""
|
||||
if not isinstance(library_lambda, _LibraryLambda):
|
||||
raise TypeError(
|
||||
f"library_lambda must be of type _LibraryLambda, not {type(library_lambda)}"
|
||||
)
|
||||
super().__init__(library_lambda)
|
||||
@@ -0,0 +1,202 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""LibraryCompilerSupport.
|
||||
|
||||
Library compilation provides a way to compile an MLIR program into a library that can be later loaded
|
||||
to execute the compiled code.
|
||||
"""
|
||||
import os
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
LibraryLambdaSupport as _LibraryLambdaSupport,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .compilation_options import CompilationOptions
|
||||
from .library_compilation_result import LibraryCompilationResult
|
||||
from .public_arguments import PublicArguments
|
||||
from .library_lambda import LibraryLambda
|
||||
from .public_result import PublicResult
|
||||
from .client_parameters import ClientParameters
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
# Default output path for compiled libraries
|
||||
DEFAULT_OUTPUT_PATH = os.path.abspath(
|
||||
os.path.join(os.path.curdir, "concrete-compiler_output_lib")
|
||||
)
|
||||
|
||||
|
||||
class LibraryLambdaSupport(WrapperCpp):
|
||||
"""Support class for library compilation and execution."""
|
||||
|
||||
def __init__(self, library_lambda_support: _LibraryLambdaSupport):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
library_lambda_support (_LibraryLambdaSupport): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if library_lambda_support is not of type _LibraryLambdaSupport
|
||||
"""
|
||||
if not isinstance(library_lambda_support, _LibraryLambdaSupport):
|
||||
raise TypeError(
|
||||
f"library_lambda_support must be of type _LibraryLambdaSupport, not "
|
||||
f"{type(library_lambda_support)}"
|
||||
)
|
||||
super().__init__(library_lambda_support)
|
||||
self.library_path = DEFAULT_OUTPUT_PATH
|
||||
|
||||
@property
|
||||
def library_path(self) -> str:
|
||||
"""Path where to store compiled libraries."""
|
||||
return self._library_path
|
||||
|
||||
@library_path.setter
|
||||
def library_path(self, path: str):
|
||||
if not isinstance(path, str):
|
||||
raise TypeError(f"path must be of type str, not {type(path)}")
|
||||
self._library_path = path
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(output_path: str = DEFAULT_OUTPUT_PATH) -> "LibraryLambdaSupport":
|
||||
"""Build a LibraryLambdaSupport.
|
||||
|
||||
Args:
|
||||
output_path (str, optional): path where to store compiled libraries.
|
||||
Defaults to DEFAULT_OUTPUT_PATH.
|
||||
|
||||
Raises:
|
||||
TypeError: if output_path is not of type str
|
||||
|
||||
Returns:
|
||||
LibraryLambdaSupport
|
||||
"""
|
||||
if not isinstance(output_path, str):
|
||||
raise TypeError(f"output_path must be of type str, not {type(output_path)}")
|
||||
library_lambda_support = LibraryLambdaSupport.wrap(
|
||||
_LibraryLambdaSupport(output_path)
|
||||
)
|
||||
library_lambda_support.library_path = output_path
|
||||
return library_lambda_support
|
||||
|
||||
def compile(
|
||||
self,
|
||||
mlir_program: str,
|
||||
options: CompilationOptions = CompilationOptions.new("main"),
|
||||
) -> LibraryCompilationResult:
|
||||
"""Compile an MLIR program using Concrete dialects into a library.
|
||||
|
||||
Args:
|
||||
mlir_program (str): textual representation of the mlir program to compile
|
||||
options (CompilationOptions): compilation options
|
||||
|
||||
Raises:
|
||||
TypeError: if mlir_program is not of type str
|
||||
TypeError: if options is not of type CompilationOptions
|
||||
|
||||
Returns:
|
||||
LibraryCompilationResult: the result of the library compilation
|
||||
"""
|
||||
if not isinstance(mlir_program, str):
|
||||
raise TypeError(
|
||||
f"mlir_program must be of type str, not {type(mlir_program)}"
|
||||
)
|
||||
if not isinstance(options, CompilationOptions):
|
||||
raise TypeError(
|
||||
f"options must be of type CompilationOptions, not {type(options)}"
|
||||
)
|
||||
return LibraryCompilationResult.wrap(
|
||||
self.cpp().compile(mlir_program, options.cpp())
|
||||
)
|
||||
|
||||
def reload(self, func_name: str = "main") -> LibraryCompilationResult:
|
||||
"""Reload the library compilation result from the library_path.
|
||||
|
||||
Args:
|
||||
func_name: entrypoint function name
|
||||
|
||||
Returns:
|
||||
LibraryCompilationResult: loaded library
|
||||
"""
|
||||
if not isinstance(func_name, str):
|
||||
raise TypeError(f"func_name must be of type str, not {type(func_name)}")
|
||||
return LibraryCompilationResult.new(self.library_path, func_name)
|
||||
|
||||
def load_client_parameters(
|
||||
self, library_compilation_result: LibraryCompilationResult
|
||||
) -> ClientParameters:
|
||||
"""Load the client parameters from the library compilation result.
|
||||
|
||||
Args:
|
||||
library_compilation_result (LibraryCompilationResult): compilation result of the library
|
||||
|
||||
Raises:
|
||||
TypeError: if library_compilation_result is not of type LibraryCompilationResult
|
||||
|
||||
Returns:
|
||||
ClientParameters: appropriate client parameters for the compiled library
|
||||
"""
|
||||
if not isinstance(library_compilation_result, LibraryCompilationResult):
|
||||
raise TypeError(
|
||||
f"library_compilation_result must be of type LibraryCompilationResult, not "
|
||||
f"{type(library_compilation_result)}"
|
||||
)
|
||||
|
||||
return ClientParameters.wrap(
|
||||
self.cpp().load_client_parameters(library_compilation_result.cpp())
|
||||
)
|
||||
|
||||
def load_server_lambda(
|
||||
self, library_compilation_result: LibraryCompilationResult
|
||||
) -> LibraryLambda:
|
||||
"""Load the server lambda from the library compilation result.
|
||||
|
||||
Args:
|
||||
library_compilation_result (LibraryCompilationResult): compilation result of the library
|
||||
|
||||
Raises:
|
||||
TypeError: if library_compilation_result is not of type LibraryCompilationResult
|
||||
|
||||
Returns:
|
||||
LibraryLambda: executable reference to the library
|
||||
"""
|
||||
if not isinstance(library_compilation_result, LibraryCompilationResult):
|
||||
raise TypeError(
|
||||
f"library_compilation_result must be of type LibraryCompilationResult, not "
|
||||
f"{type(library_compilation_result)}"
|
||||
)
|
||||
return LibraryLambda.wrap(
|
||||
self.cpp().load_server_lambda(library_compilation_result.cpp())
|
||||
)
|
||||
|
||||
def server_call(
|
||||
self, library_lambda: LibraryLambda, public_arguments: PublicArguments
|
||||
) -> PublicResult:
|
||||
"""Call the library with public_arguments.
|
||||
|
||||
Args:
|
||||
library_lambda (LibraryLambda): reference to the compiled library
|
||||
public_arguments (PublicArguments): arguments to use for execution
|
||||
|
||||
Raises:
|
||||
TypeError: if library_lambda is not of type LibraryLambda
|
||||
TypeError: if public_arguments is not of type PublicArguments
|
||||
|
||||
Returns:
|
||||
PublicResult: result of the execution
|
||||
"""
|
||||
if not isinstance(library_lambda, LibraryLambda):
|
||||
raise TypeError(
|
||||
f"library_lambda must be of type LibraryLambda, not {type(library_lambda)}"
|
||||
)
|
||||
if not isinstance(public_arguments, PublicArguments):
|
||||
raise TypeError(
|
||||
f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
|
||||
)
|
||||
return PublicResult.wrap(
|
||||
self.cpp().server_call(library_lambda.cpp(), public_arguments.cpp())
|
||||
)
|
||||
@@ -0,0 +1,35 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""PublicArguments."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
PublicArguments as _PublicArguments,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class PublicArguments(WrapperCpp):
|
||||
"""PublicArguments holds encrypted and plain arguments, as well as public materials.
|
||||
|
||||
An encrypted computation may require both encrypted and plain arguments, PublicArguments holds both
|
||||
types, but also other public materials, such as public keys, which are required for private computation.
|
||||
"""
|
||||
|
||||
def __init__(self, public_arguments: _PublicArguments):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
public_arguments (_PublicArguments): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if public_arguments is not of type _PublicArguments
|
||||
"""
|
||||
if not isinstance(public_arguments, _PublicArguments):
|
||||
raise TypeError(
|
||||
f"public_arguments must be of type _PublicArguments, not {type(public_arguments)}"
|
||||
)
|
||||
super().__init__(public_arguments)
|
||||
@@ -0,0 +1,31 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""PublicResult."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
PublicResult as _PublicResult,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class PublicResult(WrapperCpp):
|
||||
"""PublicResult holds the result of an encrypted execution and can be decrypted using ClientSupport."""
|
||||
|
||||
def __init__(self, public_result: _PublicResult):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
public_result (_PublicResult): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if public_result is not of type _PublicResult
|
||||
"""
|
||||
if not isinstance(public_result, _PublicResult):
|
||||
raise TypeError(
|
||||
f"public_result must be of type _PublicResult, not {type(public_result)}"
|
||||
)
|
||||
super().__init__(public_result)
|
||||
35
compiler/lib/Bindings/Python/concrete/compiler/utils.py
Normal file
35
compiler/lib/Bindings/Python/concrete/compiler/utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""Common utils for the compiler submodule."""
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64)
|
||||
ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS
|
||||
ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS
|
||||
|
||||
|
||||
def lookup_runtime_lib() -> str:
|
||||
"""Try to find the absolute path to the runtime library.
|
||||
|
||||
Returns:
|
||||
str: absolute path to the runtime library, or empty str if unsuccessful.
|
||||
"""
|
||||
# Go up to site-packages level
|
||||
cwd = os.path.abspath(__file__)
|
||||
cwd = os.path.abspath(os.path.join(cwd, os.pardir))
|
||||
cwd = os.path.abspath(os.path.join(cwd, os.pardir))
|
||||
package_name = "concrete_compiler"
|
||||
libs_path = os.path.join(cwd, f"{package_name}.libs")
|
||||
# Can be because it's not a properly installed package
|
||||
if not os.path.exists(libs_path):
|
||||
return ""
|
||||
runtime_library_paths = [
|
||||
filename
|
||||
for filename in os.listdir(libs_path)
|
||||
if filename.startswith("libConcretelangRuntime")
|
||||
]
|
||||
assert len(runtime_library_paths) == 1, "should be one and only one runtime library"
|
||||
return os.path.join(libs_path, runtime_library_paths[0])
|
||||
39
compiler/lib/Bindings/Python/concrete/compiler/wrapper.py
Normal file
39
compiler/lib/Bindings/Python/concrete/compiler/wrapper.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""Wrapper for native Cpp objects."""
|
||||
|
||||
|
||||
class WrapperCpp:
|
||||
"""Wrapper base class for native Cpp objects.
|
||||
|
||||
Initialization should mainly store the wrapped object, and future calls to the wrapper will be forwarded
|
||||
to it. A static wrap method is provided to be more explicit. Wrappers should always be constructed using
|
||||
the new method, which construct the Cpp object using the provided arguments, then wrap it. Classes that
|
||||
inherit from this class should preferably type check the wrapped object during calls to init, and
|
||||
reimplement the new method if the class is meant to be constructed.
|
||||
"""
|
||||
|
||||
def __init__(self, cpp_obj):
|
||||
self._cpp_obj = cpp_obj
|
||||
|
||||
@classmethod
|
||||
def wrap(cls, cpp_obj) -> "WrapperCpp":
|
||||
"""Wrap the Cpp object into a Python object.
|
||||
|
||||
Args:
|
||||
cpp_obj: object to wrap
|
||||
|
||||
Returns:
|
||||
WrapperCpp: wrapper
|
||||
"""
|
||||
return cls(cpp_obj)
|
||||
|
||||
@staticmethod
|
||||
def new(*args, **kwargs):
|
||||
"""Create a new wrapper by building the underlying object with a specific set of arguments."""
|
||||
raise RuntimeError("This class shouldn't be built")
|
||||
|
||||
def cpp(self):
|
||||
"""Return the Cpp wrapped object."""
|
||||
return self._cpp_obj
|
||||
Reference in New Issue
Block a user