chore: format python code with black

This commit is contained in:
youben11
2022-03-29 14:16:53 +01:00
committed by Ayoub Benaissa
parent 17c72f2e2d
commit 51308058c1
10 changed files with 133 additions and 95 deletions

View File

@@ -8,7 +8,7 @@ from concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
from concrete.compiler import ClientSupport
from concrete.compiler import KeySetCache
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
keySetCacheTest = KeySetCache(KEY_SET_CACHE_PATH)
@@ -18,8 +18,7 @@ def compile_and_run(engine, mlir_input, args, expected_result):
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
public_arguments = ClientSupport.encrypt_arguments(
client_parameters, key_set, args)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
public_result = engine.server_call(server_lambda, public_arguments)
@@ -199,8 +198,7 @@ end_to_end_fixture = [
""",
(
np.array(
[[31, 6, 12, 9], [31, 6, 12, 9], [
31, 6, 12, 9], [31, 6, 12, 9]],
[[31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9]],
dtype=np.uint8,
),
np.array(
@@ -245,28 +243,19 @@ end_to_end_fixture = [
]
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
end_to_end_fixture
)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_jit_compile_and_run(mlir_input, args, expected_result):
engine = JITCompilerSupport()
compile_and_run(engine, mlir_input, args, expected_result)
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
end_to_end_fixture
)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_and_run(mlir_input, args, expected_result):
engine = LibraryCompilerSupport("py_test_lib_compile_and_run")
compile_and_run(engine, mlir_input, args, expected_result)
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
end_to_end_fixture
)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
engine = LibraryCompilerSupport("test_lib_compile_reload_and_run")
# Here don't save compilation result, reload
@@ -275,8 +264,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
public_arguments = ClientSupport.encrypt_arguments(
client_parameters, key_set, args)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
public_result = engine.server_call(server_lambda, public_arguments)
@@ -290,7 +278,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
assert np.all(result == expected_result)
@ pytest.mark.parametrize(
@pytest.mark.parametrize(
"mlir_input, args",
[
pytest.param(
@@ -307,13 +295,14 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
)
def test_compile_and_run_invalid_arg_number(mlir_input, args):
engine = CompilerEngine()
engine.compile_fhe(
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
with pytest.raises(RuntimeError, match=r"function has arity 2 but is applied to too many arguments"):
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
with pytest.raises(
RuntimeError, match=r"function has arity 2 but is applied to too many arguments"
):
engine.run(*args)
@ pytest.mark.parametrize(
@pytest.mark.parametrize(
"mlir_input, args, expected_result, tab_size",
[
pytest.param(
@@ -333,12 +322,11 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
)
def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
engine = CompilerEngine()
engine.compile_fhe(
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
@ pytest.mark.parametrize(
@pytest.mark.parametrize(
"mlir_input",
[
pytest.param(
@@ -356,6 +344,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
)
def test_compile_invalid(mlir_input):
engine = CompilerEngine()
with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"):
engine.compile_fhe(
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
with pytest.raises(
RuntimeError, match=r"cannot find the function for generate client parameters"
):
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)

View File

@@ -7,6 +7,7 @@ from concrete.compiler import CompilerEngine
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
@pytest.mark.parallel
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
@@ -45,7 +46,7 @@ def test_compile_and_run_parallel(mlir_input, args, expected_result):
engine.compile_fhe(
mlir_input,
unsecure_key_set_cache_path=KEY_SET_CACHE_PATH,
auto_parallelize=True
auto_parallelize=True,
)
if isinstance(expected_result, int):
assert engine.run(*args) == expected_result

View File

@@ -8,26 +8,27 @@ from test_compiler_file_output.utils import assert_exists, content, remove, run
TEST_PATH = os.path.dirname(__file__)
CCOMPILER = 'cc'
CONCRETECOMPILER = 'concretecompiler'
CCOMPILER = "cc"
CONCRETECOMPILER = "concretecompiler"
SOURCE_1 = f'{TEST_PATH}/return_13.ir'
SOURCE_2 = f'{TEST_PATH}/return_0.ir'
SOURCE_C_1 = f'{TEST_PATH}/main_return_13.c'
SOURCE_C_2 = f'{TEST_PATH}/main_return_0.c'
OUTPUT = f'{TEST_PATH}/output.mlir'
LIB = f'{TEST_PATH}/outlib'
LIB_STATIC = LIB + '.a'
DYNAMIC_LIB_EXT = '.dylib' if sys.platform == 'darwin' else '.so'
SOURCE_1 = f"{TEST_PATH}/return_13.ir"
SOURCE_2 = f"{TEST_PATH}/return_0.ir"
SOURCE_C_1 = f"{TEST_PATH}/main_return_13.c"
SOURCE_C_2 = f"{TEST_PATH}/main_return_0.c"
OUTPUT = f"{TEST_PATH}/output.mlir"
LIB = f"{TEST_PATH}/outlib"
LIB_STATIC = LIB + ".a"
DYNAMIC_LIB_EXT = ".dylib" if sys.platform == "darwin" else ".so"
LIB_DYNAMIC = LIB + DYNAMIC_LIB_EXT
LIBS = (LIB_STATIC, LIB_DYNAMIC)
assert_exists(SOURCE_1, SOURCE_2, SOURCE_C_1, SOURCE_C_2)
def test_roundtrip():
remove(OUTPUT)
run(CONCRETECOMPILER, SOURCE_1, '--action=roundtrip', '-o', OUTPUT)
run(CONCRETECOMPILER, SOURCE_1, "--action=roundtrip", "-o", OUTPUT)
assert_exists(OUTPUT)
assert content(SOURCE_1) == content(OUTPUT)
@@ -38,7 +39,7 @@ def test_roundtrip():
def test_roundtrip_many():
remove(OUTPUT)
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, '--action=roundtrip', '-o', OUTPUT)
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, "--action=roundtrip", "-o", OUTPUT)
assert_exists(OUTPUT)
assert f"{content(SOURCE_1)}{content(SOURCE_2)}" == content(OUTPUT)
@@ -49,35 +50,36 @@ def test_roundtrip_many():
def test_compile_library():
remove(LIBS)
run(CONCRETECOMPILER, SOURCE_1, '--action=compile', '-o', LIB)
run(CONCRETECOMPILER, SOURCE_1, "--action=compile", "-o", LIB)
assert_exists(LIBS)
EXE = './main.exe'
EXE = "./main.exe"
remove(EXE)
run(CCOMPILER, '-o', EXE, SOURCE_C_1, LIB_STATIC)
run(CCOMPILER, "-o", EXE, SOURCE_C_1, LIB_STATIC)
result = subprocess.run([EXE], capture_output=True)
assert 13 == result.returncode
remove(EXE)
run(CCOMPILER, '-o', EXE, SOURCE_C_1, LIB_DYNAMIC)
run(CCOMPILER, "-o", EXE, SOURCE_C_1, LIB_DYNAMIC)
result = subprocess.run([EXE], capture_output=True)
assert 13 == result.returncode
remove(LIBS, EXE)
def test_compile_many_library():
remove(LIBS)
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, '--action=compile', '-o', LIB)
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, "--action=compile", "-o", LIB)
assert_exists(LIBS)
EXE = './main.exe'
EXE = "./main.exe"
remove(EXE)
run(CCOMPILER, '-o', EXE, SOURCE_C_2, LIB_DYNAMIC)
run(CCOMPILER, "-o", EXE, SOURCE_C_2, LIB_DYNAMIC)
result = subprocess.run([EXE], capture_output=True)
assert 0 == result.returncode

View File

@@ -1,6 +1,7 @@
import os
import subprocess
def on_paths(func, *paths):
for path in paths:
try:
@@ -11,27 +12,32 @@ def on_paths(func, *paths):
except FileNotFoundError:
pass
def assert_exists(*paths):
def func(path):
if not os.path.exists(path):
dirpath = os.path.dirname(path)
if os.path.exists(dirpath):
msg = f'{path} is not in {dirpath}'
msg = f"{path} is not in {dirpath}"
else:
msg = f'{dirpath} does not exist for {path}'
msg = f"{dirpath} does not exist for {path}"
assert False, msg
on_paths(func, *paths)
def remove(*paths):
on_paths(os.remove, *paths)
def content(path):
with open(path) as f:
return f.read()
def run(*cmd):
result = subprocess.run(cmd, capture_output=True)
if result.returncode != 0:
print(result.stderr)
assert result.returncode == 0, ' '.join(cmd)
return str(result.stdout, encoding='utf-8')
assert result.returncode == 0, " ".join(cmd)
return str(result.stdout, encoding="utf-8")

View File

@@ -18,9 +18,7 @@ def test_eint_tensor(shape):
register_dialects(ctx)
eint = fhe.EncryptedIntegerType.get(ctx, 3)
tensor = RankedTensorType.get(shape, eint)
assert (
tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.eint<{3}>>"
)
assert tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.eint<{3}>>"
@pytest.mark.parametrize("width", [0])