From edd10c1436ce64e550cf71e27ca6d7f61e337aa3 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 31 Mar 2022 12:10:55 +0100 Subject: [PATCH] test: update tests to the refactored API --- compiler/tests/python/test_compiler_engine.py | 34 +++++++-------- .../python/test_compiler_engine_parallel.py | 42 +++++++++++++------ 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index d6ced1e36..f36f2067a 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -3,21 +3,20 @@ import tempfile import pytest import numpy as np -from concrete.compiler import CompilerEngine -from concrete.compiler import JITCompilerSupport, LibraryCompilerSupport +from concrete.compiler import JITLambdaSupport, LibraryLambdaSupport from concrete.compiler import ClientSupport from concrete.compiler import KeySetCache KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache") -keySetCacheTest = KeySetCache(KEY_SET_CACHE_PATH) +keyset_cache = KeySetCache.new(KEY_SET_CACHE_PATH) def compile_and_run(engine, mlir_input, args, expected_result): compilation_result = engine.compile(mlir_input) # Client client_parameters = engine.load_client_parameters(compilation_result) - key_set = ClientSupport.key_set(client_parameters, keySetCacheTest) + key_set = ClientSupport.key_set(client_parameters, keyset_cache) public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) # Server server_lambda = engine.load_server_lambda(compilation_result) @@ -245,25 +244,25 @@ end_to_end_fixture = [ @pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) def test_jit_compile_and_run(mlir_input, args, expected_result): - engine = JITCompilerSupport() + engine = JITLambdaSupport.new() compile_and_run(engine, mlir_input, args, expected_result) @pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) def test_lib_compile_and_run(mlir_input, args, expected_result): - engine = LibraryCompilerSupport("py_test_lib_compile_and_run") + engine = LibraryLambdaSupport.new("py_test_lib_compile_and_run") compile_and_run(engine, mlir_input, args, expected_result) @pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) def test_lib_compile_reload_and_run(mlir_input, args, expected_result): - engine = LibraryCompilerSupport("test_lib_compile_reload_and_run") + engine = LibraryLambdaSupport.new("test_lib_compile_reload_and_run") # Here don't save compilation result, reload engine.compile(mlir_input) compilation_result = engine.reload() # Client client_parameters = engine.load_client_parameters(compilation_result) - key_set = ClientSupport.key_set(client_parameters, keySetCacheTest) + key_set = ClientSupport.key_set(client_parameters, keyset_cache) public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) # Server server_lambda = engine.load_server_lambda(compilation_result) @@ -294,16 +293,15 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result): ], ) def test_compile_and_run_invalid_arg_number(mlir_input, args): - engine = CompilerEngine() - engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH) + engine = JITLambdaSupport.new() with pytest.raises( RuntimeError, match=r"function has arity 2 but is applied to too many arguments" ): - engine.run(*args) + compile_and_run(engine, mlir_input, args, None) @pytest.mark.parametrize( - "mlir_input, args, expected_result, tab_size", + "mlir_input, args, expected_result", [ pytest.param( """ @@ -315,15 +313,13 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args): """, (73,), 73, - 128, id="apply_lookup_table", ), ], ) -def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size): - engine = CompilerEngine() - engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH) - assert abs(engine.run(*args) - expected_result) / tab_size < 0.1 +def test_compile_and_run_tlu(mlir_input, args, expected_result): + engine = JITLambdaSupport.new() + compile_and_run(engine, mlir_input, args, expected_result) @pytest.mark.parametrize( @@ -343,8 +339,8 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size): ], ) def test_compile_invalid(mlir_input): - engine = CompilerEngine() + engine = JITLambdaSupport.new() with pytest.raises( RuntimeError, match=r"cannot find the function for generate client parameters" ): - engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH) + engine.compile(mlir_input) diff --git a/compiler/tests/python/test_compiler_engine_parallel.py b/compiler/tests/python/test_compiler_engine_parallel.py index 362af64d5..66fe89188 100644 --- a/compiler/tests/python/test_compiler_engine_parallel.py +++ b/compiler/tests/python/test_compiler_engine_parallel.py @@ -3,10 +3,37 @@ import tempfile import pytest import numpy as np -from concrete.compiler import CompilerEngine + +from concrete.compiler.client_support import ClientSupport +from concrete.compiler.compilation_options import CompilationOptions +from concrete.compiler.jit_lambda_support import JITLambdaSupport +from concrete.compiler.key_set_cache import KeySetCache KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache") +keyset_cache = KeySetCache.new(KEY_SET_CACHE_PATH) + + +def compile_and_run(engine, mlir_input, args, expected_result): + options = CompilationOptions.new("main") + options.set_auto_parallelize(True) + compilation_result = engine.compile(mlir_input, options) + # Client + client_parameters = engine.load_client_parameters(compilation_result) + key_set = ClientSupport.key_set(client_parameters, keyset_cache) + public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) + # Server + server_lambda = engine.load_server_lambda(compilation_result) + public_result = engine.server_call(server_lambda, public_arguments) + # Client + result = ClientSupport.decrypt_result(key_set, public_result) + # Check result + assert type(expected_result) == type(result) + if isinstance(expected_result, int): + assert result == expected_result + else: + assert np.all(result == expected_result) + @pytest.mark.parallel @pytest.mark.parametrize( @@ -42,14 +69,5 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache") ], ) def test_compile_and_run_parallel(mlir_input, args, expected_result): - engine = CompilerEngine() - engine.compile_fhe( - mlir_input, - unsecure_key_set_cache_path=KEY_SET_CACHE_PATH, - auto_parallelize=True, - ) - if isinstance(expected_result, int): - assert engine.run(*args) == expected_result - else: - # numpy array - assert np.all(engine.run(*args) == expected_result) + engine = JITLambdaSupport.new() + compile_and_run(engine, mlir_input, args, expected_result)