mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
chore: format python code with black
This commit is contained in:
@@ -8,7 +8,7 @@ from concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
|
||||
from concrete.compiler import ClientSupport
|
||||
from concrete.compiler import KeySetCache
|
||||
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
|
||||
|
||||
keySetCacheTest = KeySetCache(KEY_SET_CACHE_PATH)
|
||||
|
||||
@@ -18,8 +18,7 @@ def compile_and_run(engine, mlir_input, args, expected_result):
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, args)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
public_result = engine.server_call(server_lambda, public_arguments)
|
||||
@@ -199,8 +198,7 @@ end_to_end_fixture = [
|
||||
""",
|
||||
(
|
||||
np.array(
|
||||
[[31, 6, 12, 9], [31, 6, 12, 9], [
|
||||
31, 6, 12, 9], [31, 6, 12, 9]],
|
||||
[[31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
np.array(
|
||||
@@ -245,28 +243,19 @@ end_to_end_fixture = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
end_to_end_fixture
|
||||
)
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_jit_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = JITCompilerSupport()
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
end_to_end_fixture
|
||||
)
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_lib_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = LibraryCompilerSupport("py_test_lib_compile_and_run")
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
end_to_end_fixture
|
||||
)
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
engine = LibraryCompilerSupport("test_lib_compile_reload_and_run")
|
||||
# Here don't save compilation result, reload
|
||||
@@ -275,8 +264,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, args)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
public_result = engine.server_call(server_lambda, public_arguments)
|
||||
@@ -290,7 +278,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
assert np.all(result == expected_result)
|
||||
|
||||
|
||||
@ pytest.mark.parametrize(
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -307,13 +295,14 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
)
|
||||
def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
with pytest.raises(RuntimeError, match=r"function has arity 2 but is applied to too many arguments"):
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
with pytest.raises(
|
||||
RuntimeError, match=r"function has arity 2 but is applied to too many arguments"
|
||||
):
|
||||
engine.run(*args)
|
||||
|
||||
|
||||
@ pytest.mark.parametrize(
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result, tab_size",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -333,12 +322,11 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
)
|
||||
def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
|
||||
|
||||
|
||||
@ pytest.mark.parametrize(
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -356,6 +344,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
)
|
||||
def test_compile_invalid(mlir_input):
|
||||
engine = CompilerEngine()
|
||||
with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"):
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
with pytest.raises(
|
||||
RuntimeError, match=r"cannot find the function for generate client parameters"
|
||||
):
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
|
||||
@@ -7,6 +7,7 @@ from concrete.compiler import CompilerEngine
|
||||
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
|
||||
|
||||
|
||||
@pytest.mark.parallel
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
@@ -45,7 +46,7 @@ def test_compile_and_run_parallel(mlir_input, args, expected_result):
|
||||
engine.compile_fhe(
|
||||
mlir_input,
|
||||
unsecure_key_set_cache_path=KEY_SET_CACHE_PATH,
|
||||
auto_parallelize=True
|
||||
auto_parallelize=True,
|
||||
)
|
||||
if isinstance(expected_result, int):
|
||||
assert engine.run(*args) == expected_result
|
||||
|
||||
@@ -8,26 +8,27 @@ from test_compiler_file_output.utils import assert_exists, content, remove, run
|
||||
|
||||
TEST_PATH = os.path.dirname(__file__)
|
||||
|
||||
CCOMPILER = 'cc'
|
||||
CONCRETECOMPILER = 'concretecompiler'
|
||||
CCOMPILER = "cc"
|
||||
CONCRETECOMPILER = "concretecompiler"
|
||||
|
||||
SOURCE_1 = f'{TEST_PATH}/return_13.ir'
|
||||
SOURCE_2 = f'{TEST_PATH}/return_0.ir'
|
||||
SOURCE_C_1 = f'{TEST_PATH}/main_return_13.c'
|
||||
SOURCE_C_2 = f'{TEST_PATH}/main_return_0.c'
|
||||
OUTPUT = f'{TEST_PATH}/output.mlir'
|
||||
LIB = f'{TEST_PATH}/outlib'
|
||||
LIB_STATIC = LIB + '.a'
|
||||
DYNAMIC_LIB_EXT = '.dylib' if sys.platform == 'darwin' else '.so'
|
||||
SOURCE_1 = f"{TEST_PATH}/return_13.ir"
|
||||
SOURCE_2 = f"{TEST_PATH}/return_0.ir"
|
||||
SOURCE_C_1 = f"{TEST_PATH}/main_return_13.c"
|
||||
SOURCE_C_2 = f"{TEST_PATH}/main_return_0.c"
|
||||
OUTPUT = f"{TEST_PATH}/output.mlir"
|
||||
LIB = f"{TEST_PATH}/outlib"
|
||||
LIB_STATIC = LIB + ".a"
|
||||
DYNAMIC_LIB_EXT = ".dylib" if sys.platform == "darwin" else ".so"
|
||||
LIB_DYNAMIC = LIB + DYNAMIC_LIB_EXT
|
||||
LIBS = (LIB_STATIC, LIB_DYNAMIC)
|
||||
|
||||
assert_exists(SOURCE_1, SOURCE_2, SOURCE_C_1, SOURCE_C_2)
|
||||
|
||||
|
||||
def test_roundtrip():
|
||||
remove(OUTPUT)
|
||||
|
||||
run(CONCRETECOMPILER, SOURCE_1, '--action=roundtrip', '-o', OUTPUT)
|
||||
run(CONCRETECOMPILER, SOURCE_1, "--action=roundtrip", "-o", OUTPUT)
|
||||
|
||||
assert_exists(OUTPUT)
|
||||
assert content(SOURCE_1) == content(OUTPUT)
|
||||
@@ -38,7 +39,7 @@ def test_roundtrip():
|
||||
def test_roundtrip_many():
|
||||
remove(OUTPUT)
|
||||
|
||||
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, '--action=roundtrip', '-o', OUTPUT)
|
||||
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, "--action=roundtrip", "-o", OUTPUT)
|
||||
|
||||
assert_exists(OUTPUT)
|
||||
assert f"{content(SOURCE_1)}{content(SOURCE_2)}" == content(OUTPUT)
|
||||
@@ -49,35 +50,36 @@ def test_roundtrip_many():
|
||||
def test_compile_library():
|
||||
remove(LIBS)
|
||||
|
||||
run(CONCRETECOMPILER, SOURCE_1, '--action=compile', '-o', LIB)
|
||||
run(CONCRETECOMPILER, SOURCE_1, "--action=compile", "-o", LIB)
|
||||
|
||||
assert_exists(LIBS)
|
||||
|
||||
EXE = './main.exe'
|
||||
EXE = "./main.exe"
|
||||
remove(EXE)
|
||||
run(CCOMPILER, '-o', EXE, SOURCE_C_1, LIB_STATIC)
|
||||
run(CCOMPILER, "-o", EXE, SOURCE_C_1, LIB_STATIC)
|
||||
|
||||
result = subprocess.run([EXE], capture_output=True)
|
||||
assert 13 == result.returncode
|
||||
|
||||
remove(EXE)
|
||||
run(CCOMPILER, '-o', EXE, SOURCE_C_1, LIB_DYNAMIC)
|
||||
run(CCOMPILER, "-o", EXE, SOURCE_C_1, LIB_DYNAMIC)
|
||||
|
||||
result = subprocess.run([EXE], capture_output=True)
|
||||
assert 13 == result.returncode
|
||||
|
||||
remove(LIBS, EXE)
|
||||
|
||||
|
||||
def test_compile_many_library():
|
||||
remove(LIBS)
|
||||
|
||||
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, '--action=compile', '-o', LIB)
|
||||
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, "--action=compile", "-o", LIB)
|
||||
|
||||
assert_exists(LIBS)
|
||||
|
||||
EXE = './main.exe'
|
||||
EXE = "./main.exe"
|
||||
remove(EXE)
|
||||
run(CCOMPILER, '-o', EXE, SOURCE_C_2, LIB_DYNAMIC)
|
||||
run(CCOMPILER, "-o", EXE, SOURCE_C_2, LIB_DYNAMIC)
|
||||
|
||||
result = subprocess.run([EXE], capture_output=True)
|
||||
assert 0 == result.returncode
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
|
||||
def on_paths(func, *paths):
|
||||
for path in paths:
|
||||
try:
|
||||
@@ -11,27 +12,32 @@ def on_paths(func, *paths):
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def assert_exists(*paths):
|
||||
def func(path):
|
||||
if not os.path.exists(path):
|
||||
dirpath = os.path.dirname(path)
|
||||
if os.path.exists(dirpath):
|
||||
msg = f'{path} is not in {dirpath}'
|
||||
msg = f"{path} is not in {dirpath}"
|
||||
else:
|
||||
msg = f'{dirpath} does not exist for {path}'
|
||||
msg = f"{dirpath} does not exist for {path}"
|
||||
assert False, msg
|
||||
|
||||
on_paths(func, *paths)
|
||||
|
||||
|
||||
def remove(*paths):
|
||||
on_paths(os.remove, *paths)
|
||||
|
||||
|
||||
def content(path):
|
||||
with open(path) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def run(*cmd):
|
||||
result = subprocess.run(cmd, capture_output=True)
|
||||
if result.returncode != 0:
|
||||
print(result.stderr)
|
||||
assert result.returncode == 0, ' '.join(cmd)
|
||||
return str(result.stdout, encoding='utf-8')
|
||||
assert result.returncode == 0, " ".join(cmd)
|
||||
return str(result.stdout, encoding="utf-8")
|
||||
|
||||
@@ -18,9 +18,7 @@ def test_eint_tensor(shape):
|
||||
register_dialects(ctx)
|
||||
eint = fhe.EncryptedIntegerType.get(ctx, 3)
|
||||
tensor = RankedTensorType.get(shape, eint)
|
||||
assert (
|
||||
tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.eint<{3}>>"
|
||||
)
|
||||
assert tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.eint<{3}>>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("width", [0])
|
||||
|
||||
Reference in New Issue
Block a user