test: update tests to the refactored API

This commit is contained in:
youben11
2022-03-31 12:10:55 +01:00
committed by Ayoub Benaissa
parent 999ab4e5ea
commit edd10c1436
2 changed files with 45 additions and 31 deletions

View File

@@ -3,10 +3,37 @@ import tempfile
import pytest
import numpy as np
from concrete.compiler import CompilerEngine
from concrete.compiler.client_support import ClientSupport
from concrete.compiler.compilation_options import CompilationOptions
from concrete.compiler.jit_lambda_support import JITLambdaSupport
from concrete.compiler.key_set_cache import KeySetCache
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
keyset_cache = KeySetCache.new(KEY_SET_CACHE_PATH)
def compile_and_run(engine, mlir_input, args, expected_result):
options = CompilationOptions.new("main")
options.set_auto_parallelize(True)
compilation_result = engine.compile(mlir_input, options)
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
public_result = engine.server_call(server_lambda, public_arguments)
# Client
result = ClientSupport.decrypt_result(key_set, public_result)
# Check result
assert type(expected_result) == type(result)
if isinstance(expected_result, int):
assert result == expected_result
else:
assert np.all(result == expected_result)
@pytest.mark.parallel
@pytest.mark.parametrize(
@@ -42,14 +69,5 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
],
)
def test_compile_and_run_parallel(mlir_input, args, expected_result):
engine = CompilerEngine()
engine.compile_fhe(
mlir_input,
unsecure_key_set_cache_path=KEY_SET_CACHE_PATH,
auto_parallelize=True,
)
if isinstance(expected_result, int):
assert engine.run(*args) == expected_result
else:
# numpy array
assert np.all(engine.run(*args) == expected_result)
engine = JITLambdaSupport.new()
compile_and_run(engine, mlir_input, args, expected_result)