mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
chore: format python code with black
This commit is contained in:
@@ -251,6 +251,11 @@ generate_conv_op:
|
||||
python -m mlir.dialects.linalg.opdsl.dump_oplib ops.core_named_ops > ops/LinalgNamedStructuredOps.yaml
|
||||
$(BUILD_DIR)/bin/mlir-linalg-ods-yaml-gen ops/LinalgNamedStructuredOps.yaml --o-impl=ops/LinalgOps.cpp --o-ods-decl=ops/LinalgNamedStructuredOps.yamlgen.td
|
||||
|
||||
check_python_format:
|
||||
black --check tests/python/ lib/Bindings/Python/concrete/
|
||||
|
||||
python_format:
|
||||
black tests/python/ lib/Bindings/Python/concrete/
|
||||
|
||||
.PHONY: build-initialized \
|
||||
build-end-to-end-jit \
|
||||
@@ -277,4 +282,6 @@ generate_conv_op:
|
||||
uninstall\
|
||||
install_runtime_lib \
|
||||
uninstall_runtime_lib \
|
||||
generate_conv_op
|
||||
generate_conv_op \
|
||||
python_format \
|
||||
check_python_format
|
||||
|
||||
@@ -1 +1 @@
|
||||
__import__('pkg_resources').declare_namespace(__name__)
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
|
||||
@@ -7,7 +7,9 @@ import os
|
||||
import atexit
|
||||
from typing import List, Union
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import terminate_parallelization as _terminate_parallelization
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
terminate_parallelization as _terminate_parallelization,
|
||||
)
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
|
||||
|
||||
@@ -24,11 +26,15 @@ from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArg
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import CompilationOptions
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport as _JITLambdaSupport
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JITLambdaSupport as _JITLambdaSupport,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import JitCompilationResult
|
||||
from mlir._mlir_libs._concretelang._compiler import JITLambda
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryLambdaSupport as _LibraryLambdaSupport
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
LibraryLambdaSupport as _LibraryLambdaSupport,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryCompilationResult
|
||||
from mlir._mlir_libs._concretelang._compiler import LibraryLambda
|
||||
import numpy as np
|
||||
@@ -63,8 +69,7 @@ def _lookup_runtime_lib() -> str:
|
||||
for filename in os.listdir(libs_path)
|
||||
if filename.startswith("libConcretelangRuntime")
|
||||
]
|
||||
assert len(
|
||||
runtime_library_paths) == 1, "should be one and only one runtime library"
|
||||
assert len(runtime_library_paths) == 1, "should be one and only one runtime library"
|
||||
return os.path.join(libs_path, runtime_library_paths[0])
|
||||
|
||||
|
||||
@@ -123,21 +128,19 @@ class CompilerEngine:
|
||||
)
|
||||
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
|
||||
if not isinstance(unsecure_key_set_cache_path, str):
|
||||
raise TypeError(
|
||||
"unsecure_key_set_cache_path must be a str"
|
||||
)
|
||||
raise TypeError("unsecure_key_set_cache_path must be a str")
|
||||
options = CompilationOptions(func_name)
|
||||
options.auto_parallelize(auto_parallelize)
|
||||
options.loop_parallelize(loop_parallelize)
|
||||
options.dataflow_parallelize(df_parallelize)
|
||||
self._compilation_result = self._engine.compile(mlir_str, options)
|
||||
self._client_parameters = self._engine.load_client_parameters(
|
||||
self._compilation_result)
|
||||
self._compilation_result
|
||||
)
|
||||
keyset_cache = None
|
||||
if not unsecure_key_set_cache_path is None:
|
||||
keyset_cache = KeySetCache(unsecure_key_set_cache_path)
|
||||
self._key_set = ClientSupport.key_set(
|
||||
self._client_parameters, keyset_cache)
|
||||
self._key_set = ClientSupport.key_set(self._client_parameters, keyset_cache)
|
||||
|
||||
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
|
||||
"""Run the compiled code.
|
||||
@@ -156,19 +159,20 @@ class CompilerEngine:
|
||||
if self._compilation_result is None:
|
||||
raise RuntimeError("need to compile an MLIR code first")
|
||||
# Client
|
||||
public_arguments = ClientSupport.encrypt_arguments(self._client_parameters,
|
||||
self._key_set, args)
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
self._client_parameters, self._key_set, args
|
||||
)
|
||||
# Server
|
||||
server_lambda = self._engine.load_server_lambda(
|
||||
self._compilation_result)
|
||||
public_result = self._engine.server_call(
|
||||
server_lambda, public_arguments)
|
||||
server_lambda = self._engine.load_server_lambda(self._compilation_result)
|
||||
public_result = self._engine.server_call(server_lambda, public_arguments)
|
||||
# Client
|
||||
return ClientSupport.decrypt_result(self._key_set, public_result)
|
||||
|
||||
|
||||
class ClientSupport:
|
||||
def key_set(client_parameters: ClientParameters, cache: KeySetCache = None) -> KeySet:
|
||||
def key_set(
|
||||
client_parameters: ClientParameters, cache: KeySetCache = None
|
||||
) -> KeySet:
|
||||
"""Generates a key set according to the given client parameters.
|
||||
If the cache is set the key set is loaded from it if exists, else the new generated key set is saved in the cache
|
||||
|
||||
@@ -181,7 +185,11 @@ class ClientSupport:
|
||||
"""
|
||||
return _ClientSupport.key_set(client_parameters, cache)
|
||||
|
||||
def encrypt_arguments(client_parameters: ClientParameters, key_set: KeySet, args: List[Union[int, np.ndarray]]) -> PublicArguments:
|
||||
def encrypt_arguments(
|
||||
client_parameters: ClientParameters,
|
||||
key_set: KeySet,
|
||||
args: List[Union[int, np.ndarray]],
|
||||
) -> PublicArguments:
|
||||
"""Export clear arguments to public arguments.
|
||||
For each arguments this method encrypts the argument if it's declared as encrypted and pack to the public arguments object.
|
||||
|
||||
@@ -193,10 +201,15 @@ class ClientSupport:
|
||||
PublicArguments: the public arguments
|
||||
"""
|
||||
execution_arguments = [
|
||||
ClientSupport._create_execution_argument(arg) for arg in args]
|
||||
return _ClientSupport.encrypt_arguments(client_parameters, key_set, execution_arguments)
|
||||
ClientSupport._create_execution_argument(arg) for arg in args
|
||||
]
|
||||
return _ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, execution_arguments
|
||||
)
|
||||
|
||||
def decrypt_result(key_set: KeySet, public_result: PublicResult) -> Union[int, np.ndarray]:
|
||||
def decrypt_result(
|
||||
key_set: KeySet, public_result: PublicResult
|
||||
) -> Union[int, np.ndarray]:
|
||||
"""Decrypt a public result thanks the given key set.
|
||||
|
||||
Args:
|
||||
@@ -205,7 +218,7 @@ class ClientSupport:
|
||||
|
||||
Returns:
|
||||
int or numpy.array: The result of decryption.
|
||||
"""
|
||||
"""
|
||||
lambda_arg = _ClientSupport.decrypt_result(key_set, public_result)
|
||||
if lambda_arg.is_scalar():
|
||||
return lambda_arg.get_scalar()
|
||||
@@ -230,7 +243,8 @@ class ClientSupport:
|
||||
"""
|
||||
if not isinstance(value, ACCEPTED_TYPES):
|
||||
raise TypeError(
|
||||
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}")
|
||||
"value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
|
||||
)
|
||||
if isinstance(value, ACCEPTED_INTS):
|
||||
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
|
||||
raise TypeError(
|
||||
@@ -242,8 +256,7 @@ class ClientSupport:
|
||||
if value.shape == ():
|
||||
return _LambdaArgument.from_scalar(value)
|
||||
if value.dtype not in ACCEPTED_NUMPY_UINTS:
|
||||
raise TypeError(
|
||||
"numpy.array must be of dtype uint{8,16,32,64}")
|
||||
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
|
||||
return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
|
||||
|
||||
|
||||
@@ -258,7 +271,11 @@ class JITCompilerSupport:
|
||||
)
|
||||
self._support = _JITLambdaSupport(runtime_lib_path)
|
||||
|
||||
def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> JitCompilationResult:
|
||||
def compile(
|
||||
self,
|
||||
mlir_program: str,
|
||||
options: CompilationOptions = CompilationOptions("main"),
|
||||
) -> JitCompilationResult:
|
||||
"""JIT Compile a function define in the mlir_program to its homomorphic equivalent.
|
||||
|
||||
Args:
|
||||
@@ -272,7 +289,9 @@ class JITCompilerSupport:
|
||||
raise TypeError("mlir_program must be an `str`")
|
||||
return self._support.compile(mlir_program, options)
|
||||
|
||||
def load_client_parameters(self, compilation_result: JitCompilationResult) -> ClientParameters:
|
||||
def load_client_parameters(
|
||||
self, compilation_result: JitCompilationResult
|
||||
) -> ClientParameters:
|
||||
"""Load the client parameters from the JIT compilation result"""
|
||||
return self._support.load_client_parameters(compilation_result)
|
||||
|
||||
@@ -298,7 +317,11 @@ class LibraryCompilerSupport:
|
||||
self._library_path = outputPath
|
||||
self._support = _LibraryLambdaSupport(outputPath)
|
||||
|
||||
def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> LibraryCompilationResult:
|
||||
def compile(
|
||||
self,
|
||||
mlir_program: str,
|
||||
options: CompilationOptions = CompilationOptions("main"),
|
||||
) -> LibraryCompilationResult:
|
||||
"""Compile a function define in the mlir_program to its homomorphic equivalent and save as library.
|
||||
|
||||
Args:
|
||||
@@ -327,19 +350,24 @@ class LibraryCompilerSupport:
|
||||
raise TypeError("func_name must be an `str`")
|
||||
return LibraryCompilationResult(self._library_path, func_name)
|
||||
|
||||
def load_client_parameters(self, compilation_result: LibraryCompilationResult) -> ClientParameters:
|
||||
def load_client_parameters(
|
||||
self, compilation_result: LibraryCompilationResult
|
||||
) -> ClientParameters:
|
||||
"""Load the client parameters from the JIT compilation result"""
|
||||
if not isinstance(compilation_result, LibraryCompilationResult):
|
||||
raise TypeError(
|
||||
"compilation_result must be an `LibraryCompilationResult`")
|
||||
raise TypeError("compilation_result must be an `LibraryCompilationResult`")
|
||||
|
||||
return self._support.load_client_parameters(compilation_result)
|
||||
|
||||
def load_server_lambda(self, compilation_result: LibraryCompilationResult) -> LibraryLambda:
|
||||
def load_server_lambda(
|
||||
self, compilation_result: LibraryCompilationResult
|
||||
) -> LibraryLambda:
|
||||
"""Load the server lambda from the JIT compilation result"""
|
||||
return self._support.load_server_lambda(compilation_result)
|
||||
|
||||
def server_call(self, server_lambda: LibraryLambda, public_arguments: PublicArguments) -> PublicResult:
|
||||
def server_call(
|
||||
self, server_lambda: LibraryLambda, public_arguments: PublicArguments
|
||||
) -> PublicResult:
|
||||
"""Call the server lambda with public_arguments
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
|
||||
@@ -2,4 +2,12 @@
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
# We need this helpers from the mlir bindings, they are used in the generated files
|
||||
from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context, get_op_result_or_value, get_op_results_or_values
|
||||
from mlir.dialects._ods_common import (
|
||||
_cext,
|
||||
segmented_accessor,
|
||||
equally_sized_accessor,
|
||||
extend_opview_class,
|
||||
get_default_loc_context,
|
||||
get_op_result_or_value,
|
||||
get_op_results_or_values,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
|
||||
from concrete.compiler import ClientSupport
|
||||
from concrete.compiler import KeySetCache
|
||||
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
|
||||
|
||||
keySetCacheTest = KeySetCache(KEY_SET_CACHE_PATH)
|
||||
|
||||
@@ -18,8 +18,7 @@ def compile_and_run(engine, mlir_input, args, expected_result):
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, args)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
public_result = engine.server_call(server_lambda, public_arguments)
|
||||
@@ -199,8 +198,7 @@ end_to_end_fixture = [
|
||||
""",
|
||||
(
|
||||
np.array(
|
||||
[[31, 6, 12, 9], [31, 6, 12, 9], [
|
||||
31, 6, 12, 9], [31, 6, 12, 9]],
|
||||
[[31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9]],
|
||||
dtype=np.uint8,
|
||||
),
|
||||
np.array(
|
||||
@@ -245,28 +243,19 @@ end_to_end_fixture = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
end_to_end_fixture
|
||||
)
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_jit_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = JITCompilerSupport()
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
end_to_end_fixture
|
||||
)
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_lib_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = LibraryCompilerSupport("py_test_lib_compile_and_run")
|
||||
compile_and_run(engine, mlir_input, args, expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
end_to_end_fixture
|
||||
)
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
engine = LibraryCompilerSupport("test_lib_compile_reload_and_run")
|
||||
# Here don't save compilation result, reload
|
||||
@@ -275,8 +264,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keySetCacheTest)
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, args)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
public_result = engine.server_call(server_lambda, public_arguments)
|
||||
@@ -290,7 +278,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
assert np.all(result == expected_result)
|
||||
|
||||
|
||||
@ pytest.mark.parametrize(
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -307,13 +295,14 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
|
||||
)
|
||||
def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
with pytest.raises(RuntimeError, match=r"function has arity 2 but is applied to too many arguments"):
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
with pytest.raises(
|
||||
RuntimeError, match=r"function has arity 2 but is applied to too many arguments"
|
||||
):
|
||||
engine.run(*args)
|
||||
|
||||
|
||||
@ pytest.mark.parametrize(
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result, tab_size",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -333,12 +322,11 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
)
|
||||
def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
|
||||
|
||||
|
||||
@ pytest.mark.parametrize(
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -356,6 +344,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
)
|
||||
def test_compile_invalid(mlir_input):
|
||||
engine = CompilerEngine()
|
||||
with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"):
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
with pytest.raises(
|
||||
RuntimeError, match=r"cannot find the function for generate client parameters"
|
||||
):
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
|
||||
@@ -7,6 +7,7 @@ from concrete.compiler import CompilerEngine
|
||||
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
|
||||
|
||||
|
||||
@pytest.mark.parallel
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
@@ -45,7 +46,7 @@ def test_compile_and_run_parallel(mlir_input, args, expected_result):
|
||||
engine.compile_fhe(
|
||||
mlir_input,
|
||||
unsecure_key_set_cache_path=KEY_SET_CACHE_PATH,
|
||||
auto_parallelize=True
|
||||
auto_parallelize=True,
|
||||
)
|
||||
if isinstance(expected_result, int):
|
||||
assert engine.run(*args) == expected_result
|
||||
|
||||
@@ -8,26 +8,27 @@ from test_compiler_file_output.utils import assert_exists, content, remove, run
|
||||
|
||||
TEST_PATH = os.path.dirname(__file__)
|
||||
|
||||
CCOMPILER = 'cc'
|
||||
CONCRETECOMPILER = 'concretecompiler'
|
||||
CCOMPILER = "cc"
|
||||
CONCRETECOMPILER = "concretecompiler"
|
||||
|
||||
SOURCE_1 = f'{TEST_PATH}/return_13.ir'
|
||||
SOURCE_2 = f'{TEST_PATH}/return_0.ir'
|
||||
SOURCE_C_1 = f'{TEST_PATH}/main_return_13.c'
|
||||
SOURCE_C_2 = f'{TEST_PATH}/main_return_0.c'
|
||||
OUTPUT = f'{TEST_PATH}/output.mlir'
|
||||
LIB = f'{TEST_PATH}/outlib'
|
||||
LIB_STATIC = LIB + '.a'
|
||||
DYNAMIC_LIB_EXT = '.dylib' if sys.platform == 'darwin' else '.so'
|
||||
SOURCE_1 = f"{TEST_PATH}/return_13.ir"
|
||||
SOURCE_2 = f"{TEST_PATH}/return_0.ir"
|
||||
SOURCE_C_1 = f"{TEST_PATH}/main_return_13.c"
|
||||
SOURCE_C_2 = f"{TEST_PATH}/main_return_0.c"
|
||||
OUTPUT = f"{TEST_PATH}/output.mlir"
|
||||
LIB = f"{TEST_PATH}/outlib"
|
||||
LIB_STATIC = LIB + ".a"
|
||||
DYNAMIC_LIB_EXT = ".dylib" if sys.platform == "darwin" else ".so"
|
||||
LIB_DYNAMIC = LIB + DYNAMIC_LIB_EXT
|
||||
LIBS = (LIB_STATIC, LIB_DYNAMIC)
|
||||
|
||||
assert_exists(SOURCE_1, SOURCE_2, SOURCE_C_1, SOURCE_C_2)
|
||||
|
||||
|
||||
def test_roundtrip():
|
||||
remove(OUTPUT)
|
||||
|
||||
run(CONCRETECOMPILER, SOURCE_1, '--action=roundtrip', '-o', OUTPUT)
|
||||
run(CONCRETECOMPILER, SOURCE_1, "--action=roundtrip", "-o", OUTPUT)
|
||||
|
||||
assert_exists(OUTPUT)
|
||||
assert content(SOURCE_1) == content(OUTPUT)
|
||||
@@ -38,7 +39,7 @@ def test_roundtrip():
|
||||
def test_roundtrip_many():
|
||||
remove(OUTPUT)
|
||||
|
||||
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, '--action=roundtrip', '-o', OUTPUT)
|
||||
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, "--action=roundtrip", "-o", OUTPUT)
|
||||
|
||||
assert_exists(OUTPUT)
|
||||
assert f"{content(SOURCE_1)}{content(SOURCE_2)}" == content(OUTPUT)
|
||||
@@ -49,35 +50,36 @@ def test_roundtrip_many():
|
||||
def test_compile_library():
|
||||
remove(LIBS)
|
||||
|
||||
run(CONCRETECOMPILER, SOURCE_1, '--action=compile', '-o', LIB)
|
||||
run(CONCRETECOMPILER, SOURCE_1, "--action=compile", "-o", LIB)
|
||||
|
||||
assert_exists(LIBS)
|
||||
|
||||
EXE = './main.exe'
|
||||
EXE = "./main.exe"
|
||||
remove(EXE)
|
||||
run(CCOMPILER, '-o', EXE, SOURCE_C_1, LIB_STATIC)
|
||||
run(CCOMPILER, "-o", EXE, SOURCE_C_1, LIB_STATIC)
|
||||
|
||||
result = subprocess.run([EXE], capture_output=True)
|
||||
assert 13 == result.returncode
|
||||
|
||||
remove(EXE)
|
||||
run(CCOMPILER, '-o', EXE, SOURCE_C_1, LIB_DYNAMIC)
|
||||
run(CCOMPILER, "-o", EXE, SOURCE_C_1, LIB_DYNAMIC)
|
||||
|
||||
result = subprocess.run([EXE], capture_output=True)
|
||||
assert 13 == result.returncode
|
||||
|
||||
remove(LIBS, EXE)
|
||||
|
||||
|
||||
def test_compile_many_library():
|
||||
remove(LIBS)
|
||||
|
||||
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, '--action=compile', '-o', LIB)
|
||||
run(CONCRETECOMPILER, SOURCE_1, SOURCE_2, "--action=compile", "-o", LIB)
|
||||
|
||||
assert_exists(LIBS)
|
||||
|
||||
EXE = './main.exe'
|
||||
EXE = "./main.exe"
|
||||
remove(EXE)
|
||||
run(CCOMPILER, '-o', EXE, SOURCE_C_2, LIB_DYNAMIC)
|
||||
run(CCOMPILER, "-o", EXE, SOURCE_C_2, LIB_DYNAMIC)
|
||||
|
||||
result = subprocess.run([EXE], capture_output=True)
|
||||
assert 0 == result.returncode
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
|
||||
def on_paths(func, *paths):
|
||||
for path in paths:
|
||||
try:
|
||||
@@ -11,27 +12,32 @@ def on_paths(func, *paths):
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def assert_exists(*paths):
|
||||
def func(path):
|
||||
if not os.path.exists(path):
|
||||
dirpath = os.path.dirname(path)
|
||||
if os.path.exists(dirpath):
|
||||
msg = f'{path} is not in {dirpath}'
|
||||
msg = f"{path} is not in {dirpath}"
|
||||
else:
|
||||
msg = f'{dirpath} does not exist for {path}'
|
||||
msg = f"{dirpath} does not exist for {path}"
|
||||
assert False, msg
|
||||
|
||||
on_paths(func, *paths)
|
||||
|
||||
|
||||
def remove(*paths):
|
||||
on_paths(os.remove, *paths)
|
||||
|
||||
|
||||
def content(path):
|
||||
with open(path) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def run(*cmd):
|
||||
result = subprocess.run(cmd, capture_output=True)
|
||||
if result.returncode != 0:
|
||||
print(result.stderr)
|
||||
assert result.returncode == 0, ' '.join(cmd)
|
||||
return str(result.stdout, encoding='utf-8')
|
||||
assert result.returncode == 0, " ".join(cmd)
|
||||
return str(result.stdout, encoding="utf-8")
|
||||
|
||||
@@ -18,9 +18,7 @@ def test_eint_tensor(shape):
|
||||
register_dialects(ctx)
|
||||
eint = fhe.EncryptedIntegerType.get(ctx, 3)
|
||||
tensor = RankedTensorType.get(shape, eint)
|
||||
assert (
|
||||
tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.eint<{3}>>"
|
||||
)
|
||||
assert tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.eint<{3}>>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("width", [0])
|
||||
|
||||
Reference in New Issue
Block a user