diff --git a/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index faeae19da..0c0e85864 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -20,6 +20,9 @@ from .public_result import PublicResult from .public_arguments import PublicArguments from .jit_compilation_result import JITCompilationResult from .jit_lambda import JITLambda +from .lambda_argument import LambdaArgument +from .library_compilation_result import LibraryCompilationResult +from .library_lambda import LibraryLambda from .client_support import ClientSupport from .jit_support import JITSupport from .library_support import LibrarySupport diff --git a/compiler/lib/Bindings/Python/concrete/compiler/client_support.py b/compiler/lib/Bindings/Python/concrete/compiler/client_support.py index b4ad5c8f6..126ce7300 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/client_support.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/client_support.py @@ -40,7 +40,7 @@ class ClientSupport(WrapperCpp): """ if not isinstance(client_support, _ClientSupport): raise TypeError( - f"client_support must be of type _ClientSupport not {type(client_support)}" + f"client_support must be of type _ClientSupport, not {type(client_support)}" ) super().__init__(client_support) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py b/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py index 4830c1b5b..aedf35103 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py @@ -39,7 +39,7 @@ class JITSupport(WrapperCpp): """ if not isinstance(jit_support, _JITSupport): raise TypeError( - f"jit_support must be of type _JITSupport not{type(jit_support)}" + f"jit_support must be of type _JITSupport, not {type(jit_support)}" ) super().__init__(jit_support) diff --git a/compiler/tests/python/__init__.py b/compiler/tests/python/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/tests/python/conftest.py b/compiler/tests/python/conftest.py index 30e5ed38c..99765d80a 100644 --- a/compiler/tests/python/conftest.py +++ b/compiler/tests/python/conftest.py @@ -1,2 +1,15 @@ +import os +import tempfile +import pytest +from concrete.compiler import KeySetCache + +KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache") + + def pytest_configure(config): config.addinivalue_line("markers", "parallel: mark parallel tests") + + +@pytest.fixture(scope="session") +def keyset_cache(): + return KeySetCache.new(KEY_SET_CACHE_PATH) diff --git a/compiler/tests/python/test_argument_support.py b/compiler/tests/python/test_argument_support.py new file mode 100644 index 000000000..75e0d8811 --- /dev/null +++ b/compiler/tests/python/test_argument_support.py @@ -0,0 +1,79 @@ +import pytest +import numpy as np +from concrete.compiler.utils import ACCEPTED_NUMPY_UINTS +from concrete.compiler import ClientSupport + + +@pytest.mark.parametrize( + "garbage", + [ + pytest.param(None, id="None"), + pytest.param([0, 1, 2], id="list"), + pytest.param(0.5, id="float"), + pytest.param(2**70, id="large int"), + pytest.param(-8, id="negative int"), + pytest.param("aze", id="str"), + pytest.param(np.float64(0.8), id="np.float64"), + pytest.param(np.int8(9), id="np.int8"), + pytest.param(np.array([1, 2, 3], dtype=np.int64), id="np.array(np.int64)"), + ], +) +def test_invalid_arg_type(garbage): + with pytest.raises(TypeError): + ClientSupport._create_lambda_argument(garbage) + + +@pytest.mark.parametrize( + "value", + [ + pytest.param(5, id="int"), + pytest.param(np.uint8(5), id="uint8"), + pytest.param(np.uint16(7), id="uint16"), + pytest.param(np.uint32(9), id="uint32"), + pytest.param(np.uint64(1), id="uint64"), + ], +) +def test_accepted_ints(value): + try: + arg = ClientSupport._create_lambda_argument(value) + except Exception: + pytest.fail(f"value of type {type(value)} should be supported") + assert arg.is_scalar(), "should have been a scalar" + assert arg.get_scalar() == value + + +# TODO: #495 +# @pytest.mark.parametrize( +# "dtype", +# [ +# pytest.param(np.uint8, id="uint8"), +# pytest.param(np.uint16, id="uint16"), +# pytest.param(np.uint32, id="uint32"), +# pytest.param(np.uint64, id="uint64"), +# ], +# ) +# def test_accepted_ndarray(dtype): +# value = np.array([0, 1, 2], dtype=dtype) +# try: +# arg = ClientSupport._create_lambda_argument(value) +# except Exception: +# pytest.fail(f"value of type {type(value)} should be supported") + +# assert arg.is_tensor(), "should have been a tensor" +# assert np.all(np.equal(arg.get_tensor_shape(), value.shape)) +# assert np.all( +# np.equal( +# value, +# np.array(arg.get_tensor_data()).reshape(arg.get_tensor_shape()), +# ) +# ) + + +def test_accepted_array_as_scalar(): + value = np.array(7, dtype=np.uint16) + try: + arg = ClientSupport._create_lambda_argument(value) + except Exception: + pytest.fail(f"value of type {type(value)} should be supported") + assert arg.is_scalar(), "should have been a scalar" + assert arg.get_scalar() == value diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py new file mode 100644 index 000000000..7463133a2 --- /dev/null +++ b/compiler/tests/python/test_compilation.py @@ -0,0 +1,295 @@ +from typing import Union +import pytest +import numpy as np +from concrete.compiler import ( + JITSupport, + LibrarySupport, + ClientSupport, + CompilationOptions, +) + + +def assert_result(result, expected_result): + """Assert that result and expected result are equal. + + result and expected_result can be integers on numpy arrays. + """ + assert type(expected_result) == type(result) + if isinstance(expected_result, int): + assert result == expected_result + else: + assert np.all(result == expected_result) + + +def run(engine, args, compilation_result, keyset_cache): + """Execute engine on the given arguments. + + Perform required loading, encryption, execution, and decryption.""" + # Client + client_parameters = engine.load_client_parameters(compilation_result) + key_set = ClientSupport.key_set(client_parameters, keyset_cache) + public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) + # Server + server_lambda = engine.load_server_lambda(compilation_result) + public_result = engine.server_call(server_lambda, public_arguments) + # Client + result = ClientSupport.decrypt_result(key_set, public_result) + return result + + +def compile_run_assert( + engine, + mlir_input, + args, + expected_result, + keyset_cache, + options=CompilationOptions.new("main"), +): + """Compile run and assert result. + + Can take both JITSupport or LibrarySupport as engine. + """ + compilation_result = engine.compile(mlir_input, options) + result = run(engine, args, compilation_result, keyset_cache) + assert_result(result, expected_result) + + +end_to_end_fixture = [ + pytest.param( + """ + func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (5, 7), + 12, + id="add_eint_int", + ), + pytest.param( + """ + func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (np.array(4, dtype=np.uint8), np.array(5, dtype=np.uint8)), + 9, + id="add_eint_int_with_ndarray_as_scalar", + ), + pytest.param( + """ + func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { + %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> + %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (73,), + 73, + id="apply_lookup_table", + ), + pytest.param( + """ + func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> + { + %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : + (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7> + return %ret : !FHE.eint<7> + } + """, + ( + np.array([1, 2, 3, 4], dtype=np.uint8), + np.array([4, 3, 2, 1], dtype=np.uint8), + ), + 20, + id="dot_eint_int_uint8", + ), + pytest.param( + """ + func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> { + %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>> + return %res : tensor<4x!FHE.eint<6>> + } + """, + ( + np.array([31, 6, 12, 9], dtype=np.uint8), + np.array([32, 9, 2, 3], dtype=np.uint8), + ), + np.array([63, 15, 14, 12]), + id="add_eint_int_1D", + ), +] + + +@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) +def test_jit_compile_and_run(mlir_input, args, expected_result, keyset_cache): + engine = JITSupport.new() + compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache) + + +@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) +def test_lib_compile_and_run(mlir_input, args, expected_result, keyset_cache): + engine = LibrarySupport.new("./py_test_lib_compile_and_run") + compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache) + + +@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) +def test_lib_compile_reload_and_run(mlir_input, args, expected_result, keyset_cache): + engine = LibrarySupport.new("./test_lib_compile_reload_and_run") + # Here don't save compilation result, reload + engine.compile(mlir_input) + compilation_result = engine.reload() + result = run(engine, args, compilation_result, keyset_cache) + # Check result + assert_result(result, expected_result) + + +@pytest.mark.parallel +@pytest.mark.parametrize( + "mlir_input, args, expected_result", + [ + pytest.param( + """ + func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<3x4x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { + %c = arith.constant dense<[[1, 2], [3, 4], [5, 0], [1, 2]]> : tensor<4x2xi8> + %0 = "FHELinalg.matmul_eint_int"(%x, %c): (tensor<3x4x!FHE.eint<7>>, tensor<4x2xi8>) -> tensor<3x2x!FHE.eint<7>> + %1 = "FHELinalg.matmul_eint_int"(%y, %c): (tensor<3x4x!FHE.eint<7>>, tensor<4x2xi8>) -> tensor<3x2x!FHE.eint<7>> + %2 = "FHELinalg.add_eint"(%0, %1): (tensor<3x2x!FHE.eint<7>>, tensor<3x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> + return %2 : tensor<3x2x!FHE.eint<7>> + } + """, + ( + np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]], dtype=np.uint8), + np.array([[1, 2, 3, 4], [4, 2, 1, 1], [2, 3, 1, 5]], dtype=np.uint8), + ), + np.array([[52, 36], [31, 34], [42, 52]]), + id="matmul_eint_int_uint8", + ), + pytest.param( + """ + func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>, %a2: tensor<4x!FHE.eint<6>>, %a3: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> { + %1 = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>> + %2 = "FHELinalg.add_eint_int"(%a2, %a3) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>> + %res = "FHELinalg.add_eint"(%1, %2) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> + return %res : tensor<4x!FHE.eint<6>> + } + """, + ( + np.array([1, 2, 3, 4], dtype=np.uint8), + np.array([9, 8, 6, 5], dtype=np.uint8), + np.array([3, 2, 7, 0], dtype=np.uint8), + np.array([1, 4, 2, 11], dtype=np.uint8), + ), + np.array([14, 16, 18, 20]), + id="add_eint_int_1D", + ), + ], +) +@pytest.mark.parametrize( + "EngineClass", + [ + pytest.param(JITSupport, id="JIT"), + pytest.param(LibrarySupport, id="Library"), + ], +) +def test_compile_and_run_auto_parallelize( + mlir_input, args, expected_result, keyset_cache, EngineClass +): + engine = EngineClass.new() + options = CompilationOptions.new("main") + options.set_auto_parallelize(True) + compile_run_assert( + engine, mlir_input, args, expected_result, keyset_cache, options=options + ) + + +@pytest.mark.parametrize( + "mlir_input, args, expected_result", + [ + pytest.param( + """ + func @main(%x: tensor<3x4x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { + %y = arith.constant dense<[[1, 2], [3, 4], [5, 0], [1, 2]]> : tensor<4x2xi8> + %0 = "FHELinalg.matmul_eint_int"(%x, %y): (tensor<3x4x!FHE.eint<7>>, tensor<4x2xi8>) -> tensor<3x2x!FHE.eint<7>> + return %0 : tensor<3x2x!FHE.eint<7>> + } + """, + (np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]], dtype=np.uint8),), + np.array([[26, 18], [15, 16], [21, 26]]), + id="matmul_eint_int_uint8", + ), + ], +) +@pytest.mark.parametrize( + "EngineClass", + [ + pytest.param(JITSupport, id="JIT"), + pytest.param(LibrarySupport, id="Library"), + ], +) +def test_compile_and_run_loop_parallelize( + mlir_input, args, expected_result, keyset_cache, EngineClass +): + engine = EngineClass.new() + options = CompilationOptions.new("main") + options.set_loop_parallelize(True) + compile_run_assert( + engine, mlir_input, args, expected_result, keyset_cache, options=options + ) + + +@pytest.mark.parametrize( + "mlir_input, args", + [ + pytest.param( + """ + func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (5, 7, 8), + id="add_eint_int_invalid_arg_number", + ), + ], +) +@pytest.mark.parametrize( + "EngineClass", + [ + pytest.param(JITSupport, id="JIT"), + pytest.param(LibrarySupport, id="Library"), + ], +) +def test_compile_and_run_invalid_arg_number( + mlir_input, args, EngineClass, keyset_cache +): + engine = EngineClass.new() + with pytest.raises( + RuntimeError, match=r"function has arity 2 but is applied to too many arguments" + ): + compile_run_assert(engine, mlir_input, args, None, keyset_cache) + + +@pytest.mark.parametrize( + "mlir_input", + [ + pytest.param( + """ + func @test(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> + { + %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : + (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7> + return %ret : !FHE.eint<7> + } + """, + id="not @main", + ), + ], +) +def test_compile_invalid(mlir_input): + engine = JITSupport.new() + with pytest.raises( + RuntimeError, match=r"cannot find the function for generate client parameters" + ): + engine.compile(mlir_input) diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py deleted file mode 100644 index 7380af2d9..000000000 --- a/compiler/tests/python/test_compiler_engine.py +++ /dev/null @@ -1,344 +0,0 @@ -import os -import tempfile - -import pytest -import numpy as np -from concrete.compiler import JITSupport, LibrarySupport, ClientSupport, KeySetCache - -KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache") - -keyset_cache = KeySetCache.new(KEY_SET_CACHE_PATH) - - -def compile_and_run(engine, mlir_input, args, expected_result): - compilation_result = engine.compile(mlir_input) - # Client - client_parameters = engine.load_client_parameters(compilation_result) - key_set = ClientSupport.key_set(client_parameters, keyset_cache) - public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) - # Server - server_lambda = engine.load_server_lambda(compilation_result) - public_result = engine.server_call(server_lambda, public_arguments) - # Client - result = ClientSupport.decrypt_result(key_set, public_result) - # Check result - assert type(expected_result) == type(result) - if isinstance(expected_result, int): - assert result == expected_result - else: - assert np.all(result == expected_result) - - -end_to_end_fixture = [ - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { - %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (5, 7), - 12, - id="add_eint_int", - ), - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { - %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (np.array(4, dtype=np.uint8), np.array(5, dtype=np.uint8)), - 9, - id="add_eint_int_with_ndarray_as_scalar", - ), - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { - %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (np.uint8(3), np.uint8(5)), - 8, - id="add_eint_int_with_np_uint8_as_scalar", - ), - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { - %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (np.uint16(3), np.uint16(5)), - 8, - id="add_eint_int_with_np_uint16_as_scalar", - ), - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { - %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (np.uint32(3), np.uint32(5)), - 8, - id="add_eint_int_with_np_uint32_as_scalar", - ), - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { - %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (np.uint64(3), np.uint64(5)), - 8, - id="add_eint_int_with_np_uint64_as_scalar", - ), - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { - %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> - %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (73,), - 73, - id="apply_lookup_table", - ), - pytest.param( - """ - func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> - { - %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : - (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7> - return %ret : !FHE.eint<7> - } - """, - ( - np.array([1, 2, 3, 4], dtype=np.uint8), - np.array([4, 3, 2, 1], dtype=np.uint8), - ), - 20, - id="dot_eint_int_uint8", - ), - pytest.param( - """ - func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> - { - %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : - (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7> - return %ret : !FHE.eint<7> - } - """, - ( - np.array([1, 2, 3, 4], dtype=np.uint16), - np.array([4, 3, 2, 1], dtype=np.uint16), - ), - 20, - id="dot_eint_int_uint16", - ), - pytest.param( - """ - func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> - { - %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : - (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7> - return %ret : !FHE.eint<7> - } - """, - ( - np.array([1, 2, 3, 4], dtype=np.uint32), - np.array([4, 3, 2, 1], dtype=np.uint32), - ), - 20, - id="dot_eint_int_uint32", - ), - pytest.param( - """ - func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> - { - %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : - (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7> - return %ret : !FHE.eint<7> - } - """, - ( - np.array([1, 2, 3, 4], dtype=np.uint64), - np.array([4, 3, 2, 1], dtype=np.uint64), - ), - 20, - id="dot_eint_int_uint64", - ), - pytest.param( - """ - func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> { - %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>> - return %res : tensor<4x!FHE.eint<6>> - } - """, - ( - np.array([31, 6, 12, 9], dtype=np.uint8), - np.array([32, 9, 2, 3], dtype=np.uint8), - ), - np.array([63, 15, 14, 12]), - id="add_eint_int_1D", - ), - pytest.param( - """ - func @main(%a0: tensor<4x4x!FHE.eint<6>>, %a1: tensor<4x4xi7>) -> tensor<4x4x!FHE.eint<6>> { - %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x4x!FHE.eint<6>>, tensor<4x4xi7>) -> tensor<4x4x!FHE.eint<6>> - return %res : tensor<4x4x!FHE.eint<6>> - } - """, - ( - np.array( - [[31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9]], - dtype=np.uint8, - ), - np.array( - [[32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3]], - dtype=np.uint8, - ), - ), - np.array( - [ - [63, 15, 14, 12], - [63, 15, 14, 12], - [63, 15, 14, 12], - [63, 15, 14, 12], - ], - dtype=np.uint8, - ), - id="add_eint_int_2D", - ), - pytest.param( - """ - func @main(%a0: tensor<2x2x2x!FHE.eint<6>>, %a1: tensor<2x2x2xi7>) -> tensor<2x2x2x!FHE.eint<6>> { - %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x2x!FHE.eint<6>>, tensor<2x2x2xi7>) -> tensor<2x2x2x!FHE.eint<6>> - return %res : tensor<2x2x2x!FHE.eint<6>> - } - """, - ( - np.array( - [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], - dtype=np.uint8, - ), - np.array( - [[[9, 10], [11, 12]], [[13, 14], [15, 16]]], - dtype=np.uint8, - ), - ), - np.array( - [[[10, 12], [14, 16]], [[18, 20], [22, 24]]], - dtype=np.uint8, - ), - id="add_eint_int_3D", - ), -] - - -@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) -def test_jit_compile_and_run(mlir_input, args, expected_result): - engine = JITSupport.new() - compile_and_run(engine, mlir_input, args, expected_result) - - -@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) -def test_lib_compile_and_run(mlir_input, args, expected_result): - engine = LibrarySupport.new("./py_test_lib_compile_and_run") - compile_and_run(engine, mlir_input, args, expected_result) - - -@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) -def test_lib_compile_reload_and_run(mlir_input, args, expected_result): - engine = LibrarySupport.new("./test_lib_compile_reload_and_run") - # Here don't save compilation result, reload - engine.compile(mlir_input) - compilation_result = engine.reload() - # Client - client_parameters = engine.load_client_parameters(compilation_result) - key_set = ClientSupport.key_set(client_parameters, keyset_cache) - public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) - # Server - server_lambda = engine.load_server_lambda(compilation_result) - public_result = engine.server_call(server_lambda, public_arguments) - # Client - result = ClientSupport.decrypt_result(key_set, public_result) - # Check result - assert type(expected_result) == type(result) - if isinstance(expected_result, int): - assert result == expected_result - else: - assert np.all(result == expected_result) - - -@pytest.mark.parametrize( - "mlir_input, args", - [ - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { - %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (5, 7, 8), - id="add_eint_int_invalid_arg_number", - ), - ], -) -def test_compile_and_run_invalid_arg_number(mlir_input, args): - engine = JITSupport.new() - with pytest.raises( - RuntimeError, match=r"function has arity 2 but is applied to too many arguments" - ): - compile_and_run(engine, mlir_input, args, None) - - -@pytest.mark.parametrize( - "mlir_input, args, expected_result", - [ - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { - %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> - %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (73,), - 73, - id="apply_lookup_table", - ), - ], -) -def test_compile_and_run_tlu(mlir_input, args, expected_result): - engine = JITSupport.new() - compile_and_run(engine, mlir_input, args, expected_result) - - -@pytest.mark.parametrize( - "mlir_input", - [ - pytest.param( - """ - func @test(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> - { - %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : - (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7> - return %ret : !FHE.eint<7> - } - """, - id="not @main", - ), - ], -) -def test_compile_invalid(mlir_input): - engine = JITSupport.new() - with pytest.raises( - RuntimeError, match=r"cannot find the function for generate client parameters" - ): - engine.compile(mlir_input) diff --git a/compiler/tests/python/test_compiler_engine_parallel.py b/compiler/tests/python/test_compiler_engine_parallel.py deleted file mode 100644 index ac5a3196e..000000000 --- a/compiler/tests/python/test_compiler_engine_parallel.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import tempfile - -import pytest -import numpy as np - -from concrete.compiler import ClientSupport, CompilationOptions, JITSupport, KeySetCache - -KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache") - -keyset_cache = KeySetCache.new(KEY_SET_CACHE_PATH) - - -def compile_and_run( - engine, mlir_input, args, expected_result, options=CompilationOptions.new("main") -): - compilation_result = engine.compile(mlir_input, options) - # Client - client_parameters = engine.load_client_parameters(compilation_result) - key_set = ClientSupport.key_set(client_parameters, keyset_cache) - public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) - # Server - server_lambda = engine.load_server_lambda(compilation_result) - public_result = engine.server_call(server_lambda, public_arguments) - # Client - result = ClientSupport.decrypt_result(key_set, public_result) - # Check result - assert type(expected_result) == type(result) - if isinstance(expected_result, int): - assert result == expected_result - else: - assert np.all(result == expected_result) - - -@pytest.mark.parallel -@pytest.mark.parametrize( - "mlir_input, args, expected_result", - [ - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { - %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (5, 7), - 12, - id="add_eint_int", - ), - pytest.param( - """ - func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> - { - %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : - (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7> - return %ret : !FHE.eint<7> - } - """, - ( - np.array([1, 2, 3, 4], dtype=np.uint8), - np.array([4, 3, 2, 1], dtype=np.uint8), - ), - 20, - id="dot_eint_int_uint8", - ), - ], -) -def test_compile_and_run_auto_parallelize(mlir_input, args, expected_result): - engine = JITSupport.new() - options = CompilationOptions.new("main") - options.set_auto_parallelize(True) - compile_and_run(engine, mlir_input, args, expected_result, options=options) - - -@pytest.mark.parametrize( - "mlir_input, args, expected_result", - [ - pytest.param( - """ - func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { - %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> - } - """, - (5, 7), - 12, - id="add_eint_int", - ), - pytest.param( - """ - func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> - { - %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : - (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7> - return %ret : !FHE.eint<7> - } - """, - ( - np.array([1, 2, 3, 4], dtype=np.uint8), - np.array([4, 3, 2, 1], dtype=np.uint8), - ), - 20, - id="dot_eint_int_uint8", - ), - ], -) -def test_compile_and_run_loop_parallelize(mlir_input, args, expected_result): - engine = JITSupport.new() - options = CompilationOptions.new("main") - options.set_loop_parallelize(True) - compile_and_run(engine, mlir_input, args, expected_result, options=options) diff --git a/compiler/tests/python/test_utils.py b/compiler/tests/python/test_utils.py new file mode 100644 index 000000000..2f3034ac1 --- /dev/null +++ b/compiler/tests/python/test_utils.py @@ -0,0 +1,17 @@ +import re +import importlib.util +from concrete.compiler.utils import lookup_runtime_lib + + +def test_runtime_lib_path(): + # runtime library path should be found in case the package is installed + compiler_spec = importlib.util.find_spec("concrete.compiler") + # assuming installed packages should have python and site-packages as part of the path + if compiler_spec and re.match(r".*python.*site-packages.*", compiler_spec.origin): + runtime_lib_path = lookup_runtime_lib() + assert isinstance( + runtime_lib_path, str + ), f"runtime library path should be of type str, not {type(runtime_lib_path)}" + assert re.match( + r".*libConcretelangRuntime.*\.(so|dylib)$", runtime_lib_path + ), f"wrong runtime library path: {runtime_lib_path}" diff --git a/compiler/tests/python/test_wrappers.py b/compiler/tests/python/test_wrappers.py new file mode 100644 index 000000000..3e1ad18a0 --- /dev/null +++ b/compiler/tests/python/test_wrappers.py @@ -0,0 +1,45 @@ +import pytest +from concrete.compiler import ( + ClientParameters, + ClientSupport, + CompilationOptions, + JITCompilationResult, + JITLambda, + JITSupport, + KeySetCache, + KeySet, + LambdaArgument, + LibraryCompilationResult, + LibraryLambda, + LibrarySupport, + PublicArguments, + PublicResult, +) + + +@pytest.mark.parametrize("garbage", ["string here", 23, None]) +@pytest.mark.parametrize( + "WrapperClass", + [ + pytest.param(ClientParameters, id="ClientParameters"), + pytest.param(ClientSupport, id="ClientSupport"), + pytest.param(CompilationOptions, id="CompilationOptions"), + pytest.param(JITCompilationResult, id="JITCompilationResult"), + pytest.param(JITLambda, id="JITLambda"), + pytest.param(JITSupport, id="JITSupport"), + pytest.param(KeySetCache, id="KeySetCache"), + pytest.param(KeySet, id="KeySet"), + pytest.param(LambdaArgument, id="LambdaArgument"), + pytest.param(LibraryCompilationResult, id="LibraryCompilationResult"), + pytest.param(LibraryLambda, id="LibraryLambda"), + pytest.param(LibrarySupport, id="LibrarySupport"), + pytest.param(PublicArguments, id="PublicArguments"), + pytest.param(PublicResult, id="PublicResult"), + ], +) +def test_invalid_wrapping(WrapperClass, garbage): + with pytest.raises( + TypeError, + match=f"\.* must be of type _{WrapperClass.__name__}, not {type(garbage)}", + ): + WrapperClass(garbage)