From 1b984f5119bafb2dbc2c15f4a259efd3ca08c05c Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Wed, 16 Mar 2022 14:39:31 +0100 Subject: [PATCH] cleanup(python): Cleanup python bindings after refactoring of CompilerEngine --- .../lib/Bindings/Python/concrete/compiler.py | 115 ++++++------------ compiler/tests/python/test_compiler_engine.py | 58 +-------- 2 files changed, 43 insertions(+), 130 deletions(-) diff --git a/compiler/lib/Bindings/Python/concrete/compiler.py b/compiler/lib/Bindings/Python/concrete/compiler.py index 0cbda438c..eb059ffae 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler.py +++ b/compiler/lib/Bindings/Python/concrete/compiler.py @@ -7,15 +7,10 @@ import os import atexit from typing import List, Union -from mlir._mlir_libs._concretelang._compiler import ( - JitCompilerEngine as _JitCompilerEngine, - terminate_parallelization as _terminate_parallelization, -) -from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument +from mlir._mlir_libs._concretelang._compiler import terminate_parallelization as _terminate_parallelization + from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip -from mlir._mlir_libs._concretelang._compiler import library as _library -from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport -from mlir._mlir_libs._concretelang._compiler import LibraryLambdaSupport + from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport from mlir._mlir_libs._concretelang._compiler import ClientParameters @@ -25,10 +20,13 @@ from mlir._mlir_libs._concretelang._compiler import KeySetCache from mlir._mlir_libs._concretelang._compiler import PublicResult from mlir._mlir_libs._concretelang._compiler import PublicArguments +from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument +from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport as _JITLambdaSupport from mlir._mlir_libs._concretelang._compiler import JitCompilationResult from mlir._mlir_libs._concretelang._compiler import JITLambda +from mlir._mlir_libs._concretelang._compiler import LibraryLambdaSupport as _LibraryLambdaSupport from mlir._mlir_libs._concretelang._compiler import LibraryCompilationResult from mlir._mlir_libs._concretelang._compiler import LibraryLambda import numpy as np @@ -85,70 +83,6 @@ def round_trip(mlir_str: str) -> str: return _round_trip(mlir_str) -_MLIR_MODULES_TYPE = 'mlir_modules must be an `iterable` of `str` or a `str' - - -def library(library_path: str, mlir_modules: Union['Iterable[str]', str]) -> str: - """Compile the MLIR inputs to a library. - - Args: - library_path (str): destination path of the library - mlir_modules (list[str]|str): code of MLIR modules - - Raises: - TypeError: if arguments have incorrect types. - - Returns: - str: parsed MLIR input. - """ - if not isinstance(library_path, str): - raise TypeError("library_path must be a `str`") - if isinstance(mlir_modules, str): - mlir_modules = [mlir_modules] - elif isinstance(mlir_modules, list): - pass - elif isinstance(mlir_modules, Iterable): - mlir_modules = list(mlir_modules) - else: - mlir_modules = [None] - raise TypeError(_MLIR_MODULES_TYPE) - - if not all(isinstance(m, str) for m in mlir_modules): - raise TypeError(_MLIR_MODULES_TYPE) - - return _library(library_path, mlir_modules) - - -def create_execution_argument(value: Union[int, np.ndarray]) -> _LambdaArgument: - """Create an execution argument holding either an int or tensor value. - - Args: - value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array - - Raises: - TypeError: if the values aren't in the expected range, or using a wrong type - - Returns: - _LambdaArgument: lambda argument holding the appropriate value - """ - if not isinstance(value, ACCEPTED_TYPES): - raise TypeError( - "value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}") - if isinstance(value, ACCEPTED_INTS): - if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max): - raise TypeError( - "single integer must be in the range [0, 2**64 - 1] (uint64)" - ) - return _LambdaArgument.from_scalar(value) - else: - assert isinstance(value, np.ndarray) - if value.shape == (): - return _LambdaArgument.from_scalar(value) - if value.dtype not in ACCEPTED_NUMPY_UINTS: - raise TypeError("numpy.array must be of dtype uint{8,16,32,64}") - return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape) - - class CompilerEngine: def __init__(self, mlir_str: str = None): self._engine = JITCompilerSupport() @@ -264,7 +198,8 @@ class ClientSupport: Returns: PublicArguments: the public arguments """ - execution_arguments = [create_execution_argument(arg) for arg in args] + execution_arguments = [ + ClientSupport._create_execution_argument(arg) for arg in args] return _ClientSupport.encrypt_arguments(client_parameters, key_set, execution_arguments) def decrypt_result(key_set: KeySet, public_result: PublicResult) -> Union[int, np.ndarray]: @@ -287,12 +222,42 @@ class ClientSupport: else: raise RuntimeError("unknown return type") + def _create_execution_argument(value: Union[int, np.ndarray]) -> _LambdaArgument: + """Create an execution argument holding either an int or tensor value. + + Args: + value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array + + Raises: + TypeError: if the values aren't in the expected range, or using a wrong type + + Returns: + _LambdaArgument: lambda argument holding the appropriate value + """ + if not isinstance(value, ACCEPTED_TYPES): + raise TypeError( + "value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}") + if isinstance(value, ACCEPTED_INTS): + if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max): + raise TypeError( + "single integer must be in the range [0, 2**64 - 1] (uint64)" + ) + return _LambdaArgument.from_scalar(value) + else: + assert isinstance(value, np.ndarray) + if value.shape == (): + return _LambdaArgument.from_scalar(value) + if value.dtype not in ACCEPTED_NUMPY_UINTS: + raise TypeError( + "numpy.array must be of dtype uint{8,16,32,64}") + return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape) + class JITCompilerSupport: def __init__(self, runtime_lib_path=None): if runtime_lib_path is None: runtime_lib_path = _lookup_runtime_lib() - self._support = JITLambdaSupport(runtime_lib_path) + self._support = _JITLambdaSupport(runtime_lib_path) def compile(self, mlir_program: str, func_name: str = "main") -> JitCompilationResult: """JIT Compile a function define in the mlir_program to its homomorphic equivalent. @@ -332,7 +297,7 @@ class JITCompilerSupport: class LibraryCompilerSupport: def __init__(self, outputPath="./out"): self._library_path = outputPath - self._support = LibraryLambdaSupport(outputPath) + self._support = _LibraryLambdaSupport(outputPath) def compile(self, mlir_program: str, func_name: str = "main") -> LibraryCompilationResult: """Compile a function define in the mlir_program to its homomorphic equivalent and save as library. diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index 9cb72f894..4064beea3 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -3,7 +3,7 @@ import tempfile import pytest import numpy as np -from concrete.compiler import CompilerEngine, library +from concrete.compiler import CompilerEngine from lib.Bindings.Python.concrete.compiler import JITCompilerSupport, LibraryCompilerSupport from lib.Bindings.Python.concrete.compiler import ClientSupport from lib.Bindings.Python.concrete.compiler import KeySetCache @@ -313,7 +313,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args): engine.run(*args) -@pytest.mark.parametrize( +@ pytest.mark.parametrize( "mlir_input, args, expected_result, tab_size", [ pytest.param( @@ -338,7 +338,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size): assert abs(engine.run(*args) - expected_result) / tab_size < 0.1 -@pytest.mark.parametrize( +@ pytest.mark.parametrize( "mlir_input", [ pytest.param( @@ -359,55 +359,3 @@ def test_compile_invalid(mlir_input): with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"): engine.compile_fhe( mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH) - - -MODULE_1 = """ -func @test1() -{ - return -} -""" -MODULE_2 = """ -func @test2() -{ - return -} -""" -LIB_PATH = './test_library_generation.so' - -@pytest.mark.parametrize( - 'mlir_modules', - [ - pytest.param(MODULE_1, id='1 module'), - pytest.param([MODULE_1, MODULE_2], id='2 modules'), - pytest.param(iter([MODULE_1, MODULE_2]), id='iterable'), - ], -) -def test_library_generation(mlir_modules): - library_path = library(LIB_PATH, mlir_modules) - assert os.path.exists(library_path) - - -@pytest.mark.parametrize( - 'mlir_modules', - [ - pytest.param(bytes(MODULE_1, encoding='utf-8'), id='bytes vs str'), - pytest.param(None, id='not iterable'), - pytest.param([None], id='not str'), - ], -) -def test_library_generation_type_error(mlir_modules): - with pytest.raises(TypeError): - library(LIB_PATH, mlir_modules) - - -def test_library_call(): - module = """ - func @test(%a: i8) -> i8 - { - return %a : i8 - } - """ - from ctypes import CDLL - lib = CDLL(library(LIB_PATH, module)) - assert lib.test(13) == 13