mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
test: update tests to the refactored API
This commit is contained in:
@@ -3,21 +3,20 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine
|
||||
from concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
|
||||
from concrete.compiler import JITLambdaSupport, LibraryLambdaSupport
|
||||
from concrete.compiler import ClientSupport
|
||||
from concrete.compiler import KeySetCache
|
||||
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
|
||||
|
||||
keySetCacheTest = KeySetCache(KEY_SET_CACHE_PATH)
|
||||
keyset_cache = KeySetCache.new(KEY_SET_CACHE_PATH)
|
||||
|
||||
|
||||
def compile_and_run(engine, mlir_input, args, expected_result):
|
||||
compilation_result = engine.compile(mlir_input)
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
|
||||
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
@@ -245,25 +244,25 @@ end_to_end_fixture = [
|
||||
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_jit_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = JITCompilerSupport()
|
||||
engine = JITLambdaSupport.new()
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_lib_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = LibraryCompilerSupport("py_test_lib_compile_and_run")
|
||||
engine = LibraryLambdaSupport.new("py_test_lib_compile_and_run")
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
engine = LibraryCompilerSupport("test_lib_compile_reload_and_run")
|
||||
engine = LibraryLambdaSupport.new("test_lib_compile_reload_and_run")
|
||||
# Here don't save compilation result, reload
|
||||
engine.compile(mlir_input)
|
||||
compilation_result = engine.reload()
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
|
||||
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
@@ -294,16 +293,15 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
engine = JITLambdaSupport.new()
|
||||
with pytest.raises(
|
||||
RuntimeError, match=r"function has arity 2 but is applied to too many arguments"
|
||||
):
|
||||
engine.run(*args)
|
||||
compile_and_run(engine, mlir_input, args, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result, tab_size",
|
||||
"mlir_input, args, expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
"""
|
||||
@@ -315,15 +313,13 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
""",
|
||||
(73,),
|
||||
73,
|
||||
128,
|
||||
id="apply_lookup_table",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
|
||||
def test_compile_and_run_tlu(mlir_input, args, expected_result):
|
||||
engine = JITLambdaSupport.new()
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -343,8 +339,8 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
],
|
||||
)
|
||||
def test_compile_invalid(mlir_input):
|
||||
engine = CompilerEngine()
|
||||
engine = JITLambdaSupport.new()
|
||||
with pytest.raises(
|
||||
RuntimeError, match=r"cannot find the function for generate client parameters"
|
||||
):
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
engine.compile(mlir_input)
|
||||
|
||||
@@ -3,10 +3,37 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine
|
||||
|
||||
from concrete.compiler.client_support import ClientSupport
|
||||
from concrete.compiler.compilation_options import CompilationOptions
|
||||
from concrete.compiler.jit_lambda_support import JITLambdaSupport
|
||||
from concrete.compiler.key_set_cache import KeySetCache
|
||||
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
|
||||
|
||||
keyset_cache = KeySetCache.new(KEY_SET_CACHE_PATH)
|
||||
|
||||
|
||||
def compile_and_run(engine, mlir_input, args, expected_result):
|
||||
options = CompilationOptions.new("main")
|
||||
options.set_auto_parallelize(True)
|
||||
compilation_result = engine.compile(mlir_input, options)
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
public_result = engine.server_call(server_lambda, public_arguments)
|
||||
# Client
|
||||
result = ClientSupport.decrypt_result(key_set, public_result)
|
||||
# Check result
|
||||
assert type(expected_result) == type(result)
|
||||
if isinstance(expected_result, int):
|
||||
assert result == expected_result
|
||||
else:
|
||||
assert np.all(result == expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parallel
|
||||
@pytest.mark.parametrize(
|
||||
@@ -42,14 +69,5 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_parallel(mlir_input, args, expected_result):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(
|
||||
mlir_input,
|
||||
unsecure_key_set_cache_path=KEY_SET_CACHE_PATH,
|
||||
auto_parallelize=True,
|
||||
)
|
||||
if isinstance(expected_result, int):
|
||||
assert engine.run(*args) == expected_result
|
||||
else:
|
||||
# numpy array
|
||||
assert np.all(engine.run(*args) == expected_result)
|
||||
engine = JITLambdaSupport.new()
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
Reference in New Issue
Block a user