test: update tests to the refactored API

This commit is contained in:
youben11
2022-03-31 12:10:55 +01:00
committed by Ayoub Benaissa
parent 999ab4e5ea
commit edd10c1436
2 changed files with 45 additions and 31 deletions

View File

@@ -3,21 +3,20 @@ import tempfile
import pytest
import numpy as np
from concrete.compiler import CompilerEngine
from concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
from concrete.compiler import JITLambdaSupport, LibraryLambdaSupport
from concrete.compiler import ClientSupport
from concrete.compiler import KeySetCache
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
keySetCacheTest = KeySetCache(KEY_SET_CACHE_PATH)
keyset_cache = KeySetCache.new(KEY_SET_CACHE_PATH)
def compile_and_run(engine, mlir_input, args, expected_result):
compilation_result = engine.compile(mlir_input)
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
@@ -245,25 +244,25 @@ end_to_end_fixture = [
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_jit_compile_and_run(mlir_input, args, expected_result):
engine = JITCompilerSupport()
engine = JITLambdaSupport.new()
compile_and_run(engine, mlir_input, args, expected_result)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_and_run(mlir_input, args, expected_result):
engine = LibraryCompilerSupport("py_test_lib_compile_and_run")
engine = LibraryLambdaSupport.new("py_test_lib_compile_and_run")
compile_and_run(engine, mlir_input, args, expected_result)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
engine = LibraryCompilerSupport("test_lib_compile_reload_and_run")
engine = LibraryLambdaSupport.new("test_lib_compile_reload_and_run")
# Here don't save compilation result, reload
engine.compile(mlir_input)
compilation_result = engine.reload()
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
@@ -294,16 +293,15 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
],
)
def test_compile_and_run_invalid_arg_number(mlir_input, args):
engine = CompilerEngine()
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
engine = JITLambdaSupport.new()
with pytest.raises(
RuntimeError, match=r"function has arity 2 but is applied to too many arguments"
):
engine.run(*args)
compile_and_run(engine, mlir_input, args, None)
@pytest.mark.parametrize(
"mlir_input, args, expected_result, tab_size",
"mlir_input, args, expected_result",
[
pytest.param(
"""
@@ -315,15 +313,13 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
""",
(73,),
73,
128,
id="apply_lookup_table",
),
],
)
def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
engine = CompilerEngine()
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
def test_compile_and_run_tlu(mlir_input, args, expected_result):
engine = JITLambdaSupport.new()
compile_and_run(engine, mlir_input, args, expected_result)
@pytest.mark.parametrize(
@@ -343,8 +339,8 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
],
)
def test_compile_invalid(mlir_input):
engine = CompilerEngine()
engine = JITLambdaSupport.new()
with pytest.raises(
RuntimeError, match=r"cannot find the function for generate client parameters"
):
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
engine.compile(mlir_input)

View File

@@ -3,10 +3,37 @@ import tempfile
import pytest
import numpy as np
from concrete.compiler import CompilerEngine
from concrete.compiler.client_support import ClientSupport
from concrete.compiler.compilation_options import CompilationOptions
from concrete.compiler.jit_lambda_support import JITLambdaSupport
from concrete.compiler.key_set_cache import KeySetCache
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
keyset_cache = KeySetCache.new(KEY_SET_CACHE_PATH)
def compile_and_run(engine, mlir_input, args, expected_result):
options = CompilationOptions.new("main")
options.set_auto_parallelize(True)
compilation_result = engine.compile(mlir_input, options)
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
public_result = engine.server_call(server_lambda, public_arguments)
# Client
result = ClientSupport.decrypt_result(key_set, public_result)
# Check result
assert type(expected_result) == type(result)
if isinstance(expected_result, int):
assert result == expected_result
else:
assert np.all(result == expected_result)
@pytest.mark.parallel
@pytest.mark.parametrize(
@@ -42,14 +69,5 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
],
)
def test_compile_and_run_parallel(mlir_input, args, expected_result):
engine = CompilerEngine()
engine.compile_fhe(
mlir_input,
unsecure_key_set_cache_path=KEY_SET_CACHE_PATH,
auto_parallelize=True,
)
if isinstance(expected_result, int):
assert engine.run(*args) == expected_result
else:
# numpy array
assert np.all(engine.run(*args) == expected_result)
engine = JITLambdaSupport.new()
compile_and_run(engine, mlir_input, args, expected_result)