chore: format python code with black

This commit is contained in:
youben11
2022-03-29 14:16:53 +01:00
committed by Ayoub Benaissa
parent 17c72f2e2d
commit 51308058c1
10 changed files with 133 additions and 95 deletions

View File

@@ -251,6 +251,11 @@ generate_conv_op:
python -m mlir.dialects.linalg.opdsl.dump_oplib ops.core_named_ops > ops/LinalgNamedStructuredOps.yaml
$(BUILD_DIR)/bin/mlir-linalg-ods-yaml-gen ops/LinalgNamedStructuredOps.yaml --o-impl=ops/LinalgOps.cpp --o-ods-decl=ops/LinalgNamedStructuredOps.yamlgen.td
check_python_format:
black --check tests/python/ lib/Bindings/Python/concrete/
python_format:
black tests/python/ lib/Bindings/Python/concrete/
.PHONY: build-initialized \
build-end-to-end-jit \
@@ -277,4 +282,6 @@ generate_conv_op:
uninstall\
install_runtime_lib \
uninstall_runtime_lib \
generate_conv_op
generate_conv_op \
python_format \
check_python_format

View File

@@ -1 +1 @@
__import__('pkg_resources').declare_namespace(__name__)
__import__("pkg_resources").declare_namespace(__name__)

View File

@@ -7,7 +7,9 @@ import os
import atexit
from typing import List, Union
from mlir._mlir_libs._concretelang._compiler import terminate_parallelization as _terminate_parallelization
from mlir._mlir_libs._concretelang._compiler import (
terminate_parallelization as _terminate_parallelization,
)
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
@@ -24,11 +26,15 @@ from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArg
from mlir._mlir_libs._concretelang._compiler import CompilationOptions
from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport as _JITLambdaSupport
from mlir._mlir_libs._concretelang._compiler import (
JITLambdaSupport as _JITLambdaSupport,
)
from mlir._mlir_libs._concretelang._compiler import JitCompilationResult
from mlir._mlir_libs._concretelang._compiler import JITLambda
from mlir._mlir_libs._concretelang._compiler import LibraryLambdaSupport as _LibraryLambdaSupport
from mlir._mlir_libs._concretelang._compiler import (
LibraryLambdaSupport as _LibraryLambdaSupport,
)
from mlir._mlir_libs._concretelang._compiler import LibraryCompilationResult
from mlir._mlir_libs._concretelang._compiler import LibraryLambda
import numpy as np
@@ -63,8 +69,7 @@ def _lookup_runtime_lib() -> str:
for filename in os.listdir(libs_path)
if filename.startswith("libConcretelangRuntime")
]
assert len(
runtime_library_paths) == 1, "should be one and only one runtime library"
assert len(runtime_library_paths) == 1, "should be one and only one runtime library"
return os.path.join(libs_path, runtime_library_paths[0])
@@ -123,21 +128,19 @@ class CompilerEngine:
)
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
if not isinstance(unsecure_key_set_cache_path, str):
raise TypeError(
"unsecure_key_set_cache_path must be a str"
)
raise TypeError("unsecure_key_set_cache_path must be a str")
options = CompilationOptions(func_name)
options.auto_parallelize(auto_parallelize)
options.loop_parallelize(loop_parallelize)
options.dataflow_parallelize(df_parallelize)
self._compilation_result = self._engine.compile(mlir_str, options)
self._client_parameters = self._engine.load_client_parameters(
self._compilation_result)
self._compilation_result
)
keyset_cache = None
if not unsecure_key_set_cache_path is None:
keyset_cache = KeySetCache(unsecure_key_set_cache_path)
self._key_set = ClientSupport.key_set(
self._client_parameters, keyset_cache)
self._key_set = ClientSupport.key_set(self._client_parameters, keyset_cache)
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
"""Run the compiled code.
@@ -156,19 +159,20 @@ class CompilerEngine:
if self._compilation_result is None:
raise RuntimeError("need to compile an MLIR code first")
# Client
public_arguments = ClientSupport.encrypt_arguments(self._client_parameters,
self._key_set, args)
public_arguments = ClientSupport.encrypt_arguments(
self._client_parameters, self._key_set, args
)
# Server
server_lambda = self._engine.load_server_lambda(
self._compilation_result)
public_result = self._engine.server_call(
server_lambda, public_arguments)
server_lambda = self._engine.load_server_lambda(self._compilation_result)
public_result = self._engine.server_call(server_lambda, public_arguments)
# Client
return ClientSupport.decrypt_result(self._key_set, public_result)
class ClientSupport:
def key_set(client_parameters: ClientParameters, cache: KeySetCache = None) -> KeySet:
def key_set(
client_parameters: ClientParameters, cache: KeySetCache = None
) -> KeySet:
"""Generates a key set according to the given client parameters.
If the cache is set the key set is loaded from it if exists, else the new generated key set is saved in the cache
@@ -181,7 +185,11 @@ class ClientSupport:
"""
return _ClientSupport.key_set(client_parameters, cache)
def encrypt_arguments(client_parameters: ClientParameters, key_set: KeySet, args: List[Union[int, np.ndarray]]) -> PublicArguments:
def encrypt_arguments(
client_parameters: ClientParameters,
key_set: KeySet,
args: List[Union[int, np.ndarray]],
) -> PublicArguments:
"""Export clear arguments to public arguments.
For each arguments this method encrypts the argument if it's declared as encrypted and pack to the public arguments object.
@@ -193,10 +201,15 @@ class ClientSupport:
PublicArguments: the public arguments
"""
execution_arguments = [
ClientSupport._create_execution_argument(arg) for arg in args]
return _ClientSupport.encrypt_arguments(client_parameters, key_set, execution_arguments)
ClientSupport._create_execution_argument(arg) for arg in args
]
return _ClientSupport.encrypt_arguments(
client_parameters, key_set, execution_arguments
)
def decrypt_result(key_set: KeySet, public_result: PublicResult) -> Union[int, np.ndarray]:
def decrypt_result(
key_set: KeySet, public_result: PublicResult
) -> Union[int, np.ndarray]:
"""Decrypt a public result thanks the given key set.
Args:
@@ -205,7 +218,7 @@ class ClientSupport:
Returns:
int or numpy.array: The result of decryption.
"""
"""
lambda_arg = _ClientSupport.decrypt_result(key_set, public_result)
if lambda_arg.is_scalar():
return lambda_arg.get_scalar()
@@ -230,7 +243,8 @@ class ClientSupport:
"""
if not isinstance(value, ACCEPTED_TYPES):
raise TypeError(
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}")
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
)
if isinstance(value, ACCEPTED_INTS):
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
raise TypeError(
@@ -242,8 +256,7 @@ class ClientSupport:
if value.shape == ():
return _LambdaArgument.from_scalar(value)
if value.dtype not in ACCEPTED_NUMPY_UINTS:
raise TypeError(
"numpy.array must be of dtype uint{8,16,32,64}")
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
@@ -258,7 +271,11 @@ class JITCompilerSupport:
)
self._support = _JITLambdaSupport(runtime_lib_path)
def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> JitCompilationResult:
def compile(
self,
mlir_program: str,
options: CompilationOptions = CompilationOptions("main"),
) -> JitCompilationResult:
"""JIT Compile a function define in the mlir_program to its homomorphic equivalent.
Args:
@@ -272,7 +289,9 @@ class JITCompilerSupport:
raise TypeError("mlir_program must be an `str`")
return self._support.compile(mlir_program, options)
def load_client_parameters(self, compilation_result: JitCompilationResult) -> ClientParameters:
def load_client_parameters(
self, compilation_result: JitCompilationResult
) -> ClientParameters:
"""Load the client parameters from the JIT compilation result"""
return self._support.load_client_parameters(compilation_result)
@@ -298,7 +317,11 @@ class LibraryCompilerSupport:
self._library_path = outputPath
self._support = _LibraryLambdaSupport(outputPath)
def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> LibraryCompilationResult:
def compile(
self,
mlir_program: str,
options: CompilationOptions = CompilationOptions("main"),
) -> LibraryCompilationResult:
"""Compile a function define in the mlir_program to its homomorphic equivalent and save as library.
Args:
@@ -327,19 +350,24 @@ class LibraryCompilerSupport:
raise TypeError("func_name must be an `str`")
return LibraryCompilationResult(self._library_path, func_name)
def load_client_parameters(self, compilation_result: LibraryCompilationResult) -> ClientParameters:
def load_client_parameters(
self, compilation_result: LibraryCompilationResult
) -> ClientParameters:
"""Load the client parameters from the JIT compilation result"""
if not isinstance(compilation_result, LibraryCompilationResult):
raise TypeError(
"compilation_result must be an `LibraryCompilationResult`")
raise TypeError("compilation_result must be an `LibraryCompilationResult`")
return self._support.load_client_parameters(compilation_result)
def load_server_lambda(self, compilation_result: LibraryCompilationResult) -> LibraryLambda:
def load_server_lambda(
self, compilation_result: LibraryCompilationResult
) -> LibraryLambda:
"""Load the server lambda from the JIT compilation result"""
return self._support.load_server_lambda(compilation_result)
def server_call(self, server_lambda: LibraryLambda, public_arguments: PublicArguments) -> PublicResult:
def server_call(
self, server_lambda: LibraryLambda, public_arguments: PublicArguments
) -> PublicResult:
"""Call the server lambda with public_arguments
Args:

View File

@@ -1,3 +1,2 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.

View File

@@ -2,4 +2,12 @@
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
# We need this helpers from the mlir bindings, they are used in the generated files
from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context, get_op_result_or_value, get_op_results_or_values
from mlir.dialects._ods_common import (
_cext,
segmented_accessor,
equally_sized_accessor,
extend_opview_class,
get_default_loc_context,
get_op_result_or_value,
get_op_results_or_values,
)

View File

@@ -8,7 +8,7 @@ from concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
from concrete.compiler import ClientSupport
from concrete.compiler import KeySetCache
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
keySetCacheTest = KeySetCache(KEY_SET_CACHE_PATH)
@@ -18,8 +18,7 @@ def compile_and_run(engine, mlir_input, args, expected_result):
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
public_arguments = ClientSupport.encrypt_arguments(
client_parameters, key_set, args)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
public_result = engine.server_call(server_lambda, public_arguments)
@@ -199,8 +198,7 @@ end_to_end_fixture = [
""",
(
np.array(
[[31, 6, 12, 9], [31, 6, 12, 9], [
31, 6, 12, 9], [31, 6, 12, 9]],
[[31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9]],
dtype=np.uint8,
),
np.array(
@@ -245,28 +243,19 @@ end_to_end_fixture = [
]
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
end_to_end_fixture
)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_jit_compile_and_run(mlir_input, args, expected_result):
engine = JITCompilerSupport()
compile_and_run(engine, mlir_input, args, expected_result)
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
end_to_end_fixture
)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_and_run(mlir_input, args, expected_result):
engine = LibraryCompilerSupport("py_test_lib_compile_and_run")
compile_and_run(engine, mlir_input, args, expected_result)
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
end_to_end_fixture
)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
engine = LibraryCompilerSupport("test_lib_compile_reload_and_run")
# Here don't save compilation result, reload
@@ -275,8 +264,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
public_arguments = ClientSupport.encrypt_arguments(
client_parameters, key_set, args)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
public_result = engine.server_call(server_lambda, public_arguments)
@@ -290,7 +278,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
assert np.all(result == expected_result)
@ pytest.mark.parametrize(
@pytest.mark.parametrize(
"mlir_input, args",
[
pytest.param(
@@ -307,13 +295,14 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
)
def test_compile_and_run_invalid_arg_number(mlir_input, args):
engine = CompilerEngine()
engine.compile_fhe(
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
with pytest.raises(RuntimeError, match=r"function has arity 2 but is applied to too many arguments"):
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
with pytest.raises(
RuntimeError, match=r"function has arity 2 but is applied to too many arguments"
):
engine.run(*args)
@ pytest.mark.parametrize(
@pytest.mark.parametrize(
"mlir_input, args, expected_result, tab_size",
[
pytest.param(
@@ -333,12 +322,11 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
)
def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
engine = CompilerEngine()
engine.compile_fhe(
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
@ pytest.mark.parametrize(
@pytest.mark.parametrize(
"mlir_input",
[
pytest.param(
@@ -356,6 +344,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
)
def test_compile_invalid(mlir_input):
engine = CompilerEngine()
with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"):
engine.compile_fhe(
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
with pytest.raises(
RuntimeError, match=r"cannot find the function for generate client parameters"
):
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)

View File

@@ -7,6 +7,7 @@ from concrete.compiler import CompilerEngine
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
@pytest.mark.parallel
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
@@ -45,7 +46,7 @@ def test_compile_and_run_parallel(mlir_input, args, expected_result):
engine.compile_fhe(
mlir_input,
unsecure_key_set_cache_path=KEY_SET_CACHE_PATH,
auto_parallelize=True
auto_parallelize=True,
)
if isinstance(expected_result, int):
assert engine.run(*args) == expected_result

View File

@@ -8,26 +8,27 @@ from test_compiler_file_output.utils import assert_exists, content, remove, run
TEST_PATH = os.path.dirname(__file__)
CCOMPILER = 'cc'
CONCRETECOMPILER = 'concretecompiler'
CCOMPILER = "cc"
CONCRETECOMPILER = "concretecompiler"
SOURCE_1 = f'{TEST_PATH}/return_13.ir'
SOURCE_2 = f'{TEST_PATH}/return_0.ir'
SOURCE_C_1 = f'{TEST_PATH}/main_return_13.c'
SOURCE_C_2 = f'{TEST_PATH}/main_return_0.c'
OUTPUT = f'{TEST_PATH}/output.mlir'
LIB = f'{TEST_PATH}/outlib'
LIB_STATIC = LIB + '.a'
DYNAMIC_LIB_EXT = '.dylib' if sys.platform == 'darwin' else '.so'
SOURCE_1 = f"{TEST_PATH}/return_13.ir"
SOURCE_2 = f"{TEST_PATH}/return_0.ir"
SOURCE_C_1 = f"{TEST_PATH}/main_return_13.c"
SOURCE_C_2 = f"{TEST_PATH}/main_return_0.c"
OUTPUT = f"{TEST_PATH}/output.mlir"
LIB = f"{TEST_PATH}/outlib"
LIB_STATIC = LIB + ".a"
DYNAMIC_LIB_EXT = ".dylib" if sys.platform == "darwin" else ".so"
LIB_DYNAMIC = LIB + DYNAMIC_LIB_EXT
LIBS = (LIB_STATIC, LIB_DYNAMIC)
assert_exists(SOURCE_1, SOURCE_2, SOURCE_C_1, SOURCE_C_2)
def test_roundtrip():
remove(OUTPUT)
run(CONCRETECOMPILER, SOURCE_1, '--action=roundtrip', '-o', OUTPUT)
run(CONCRETECOMPILER, SOURCE_1, "--action=roundtrip", "-o", OUTPUT)
assert_exists(OUTPUT)
assert content(SOURCE_1) == content(OUTPUT)
@@ -38,7 +39,7 @@ def test_roundtrip():
def test_roundtrip_many():
remove(OUTPUT)
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, '--action=roundtrip', '-o', OUTPUT)
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, "--action=roundtrip", "-o", OUTPUT)
assert_exists(OUTPUT)
assert f"{content(SOURCE_1)}{content(SOURCE_2)}" == content(OUTPUT)
@@ -49,35 +50,36 @@ def test_roundtrip_many():
def test_compile_library():
remove(LIBS)
run(CONCRETECOMPILER, SOURCE_1, '--action=compile', '-o', LIB)
run(CONCRETECOMPILER, SOURCE_1, "--action=compile", "-o", LIB)
assert_exists(LIBS)
EXE = './main.exe'
EXE = "./main.exe"
remove(EXE)
run(CCOMPILER, '-o', EXE, SOURCE_C_1, LIB_STATIC)
run(CCOMPILER, "-o", EXE, SOURCE_C_1, LIB_STATIC)
result = subprocess.run([EXE], capture_output=True)
assert 13 == result.returncode
remove(EXE)
run(CCOMPILER, '-o', EXE, SOURCE_C_1, LIB_DYNAMIC)
run(CCOMPILER, "-o", EXE, SOURCE_C_1, LIB_DYNAMIC)
result = subprocess.run([EXE], capture_output=True)
assert 13 == result.returncode
remove(LIBS, EXE)
def test_compile_many_library():
remove(LIBS)
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, '--action=compile', '-o', LIB)
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, "--action=compile", "-o", LIB)
assert_exists(LIBS)
EXE = './main.exe'
EXE = "./main.exe"
remove(EXE)
run(CCOMPILER, '-o', EXE, SOURCE_C_2, LIB_DYNAMIC)
run(CCOMPILER, "-o", EXE, SOURCE_C_2, LIB_DYNAMIC)
result = subprocess.run([EXE], capture_output=True)
assert 0 == result.returncode

View File

@@ -1,6 +1,7 @@
import os
import subprocess
def on_paths(func, *paths):
for path in paths:
try:
@@ -11,27 +12,32 @@ def on_paths(func, *paths):
except FileNotFoundError:
pass
def assert_exists(*paths):
def func(path):
if not os.path.exists(path):
dirpath = os.path.dirname(path)
if os.path.exists(dirpath):
msg = f'{path} is not in {dirpath}'
msg = f"{path} is not in {dirpath}"
else:
msg = f'{dirpath} does not exist for {path}'
msg = f"{dirpath} does not exist for {path}"
assert False, msg
on_paths(func, *paths)
def remove(*paths):
on_paths(os.remove, *paths)
def content(path):
with open(path) as f:
return f.read()
def run(*cmd):
result = subprocess.run(cmd, capture_output=True)
if result.returncode != 0:
print(result.stderr)
assert result.returncode == 0, ' '.join(cmd)
return str(result.stdout, encoding='utf-8')
assert result.returncode == 0, " ".join(cmd)
return str(result.stdout, encoding="utf-8")

View File

@@ -18,9 +18,7 @@ def test_eint_tensor(shape):
register_dialects(ctx)
eint = fhe.EncryptedIntegerType.get(ctx, 3)
tensor = RankedTensorType.get(shape, eint)
assert (
tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.eint<{3}>>"
)
assert tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.eint<{3}>>"
@pytest.mark.parametrize("width", [0])