mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
test: update tests to the refactored API
This commit is contained in:
@@ -3,10 +3,37 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine
|
||||
|
||||
from concrete.compiler.client_support import ClientSupport
|
||||
from concrete.compiler.compilation_options import CompilationOptions
|
||||
from concrete.compiler.jit_lambda_support import JITLambdaSupport
|
||||
from concrete.compiler.key_set_cache import KeySetCache
|
||||
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
|
||||
|
||||
keyset_cache = KeySetCache.new(KEY_SET_CACHE_PATH)
|
||||
|
||||
|
||||
def compile_and_run(engine, mlir_input, args, expected_result):
|
||||
options = CompilationOptions.new("main")
|
||||
options.set_auto_parallelize(True)
|
||||
compilation_result = engine.compile(mlir_input, options)
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
public_result = engine.server_call(server_lambda, public_arguments)
|
||||
# Client
|
||||
result = ClientSupport.decrypt_result(key_set, public_result)
|
||||
# Check result
|
||||
assert type(expected_result) == type(result)
|
||||
if isinstance(expected_result, int):
|
||||
assert result == expected_result
|
||||
else:
|
||||
assert np.all(result == expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parallel
|
||||
@pytest.mark.parametrize(
|
||||
@@ -42,14 +69,5 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_parallel(mlir_input, args, expected_result):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(
|
||||
mlir_input,
|
||||
unsecure_key_set_cache_path=KEY_SET_CACHE_PATH,
|
||||
auto_parallelize=True,
|
||||
)
|
||||
if isinstance(expected_result, int):
|
||||
assert engine.run(*args) == expected_result
|
||||
else:
|
||||
# numpy array
|
||||
assert np.all(engine.run(*args) == expected_result)
|
||||
engine = JITLambdaSupport.new()
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
Reference in New Issue
Block a user