mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
committed by
Benoit Chevallier
parent
61daa49e9d
commit
809ce28b38
@@ -132,6 +132,7 @@ def compile_numpy_function(
|
||||
dataset: Iterator[Tuple[Any, ...]],
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
show_mlir: bool = False,
|
||||
) -> CompilerEngine:
|
||||
"""Main API of hnumpy, to be able to compile an homomorphic program.
|
||||
|
||||
@@ -146,6 +147,8 @@ def compile_numpy_function(
|
||||
during compilation
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
show_mlir (bool): if set, the MLIR produced by the converter and which is going
|
||||
to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
|
||||
|
||||
Returns:
|
||||
CompilerEngine: engine to run and debug the compiled graph
|
||||
@@ -171,6 +174,9 @@ def compile_numpy_function(
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(op_graph)
|
||||
|
||||
if show_mlir:
|
||||
print(f"MLIR which is going to be compiled: \n{mlir_result}")
|
||||
|
||||
# Compile the MLIR representation
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_result)
|
||||
|
||||
@@ -236,3 +236,33 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
|
||||
f"\n==================\nExpected {ref_graph_str}"
|
||||
f"\n==================\n"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 64, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x * 3, ((0, 40),), ["x"]),
|
||||
pytest.param(lambda x: 120 - x, ((40, 80),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 64, ((0, 20), (0, 20)), ["x", "y"]),
|
||||
pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]),
|
||||
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
|
||||
],
|
||||
)
|
||||
def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names):
|
||||
"""Test show_mlir option"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedValue(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
show_mlir=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user