mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-18 00:21:36 -05:00
feat(compiler): support multi-circuit compilation
This commit is contained in:
committed by
Alexandre Péré
parent
3247a28d9d
commit
9b5a2e46da
@@ -7,7 +7,8 @@ from concrete.compiler import (
|
||||
LibrarySupport,
|
||||
ClientSupport,
|
||||
CompilationOptions,
|
||||
CompilationFeedback,
|
||||
ProgramCompilationFeedback,
|
||||
CircuitCompilationFeedback,
|
||||
)
|
||||
|
||||
|
||||
@@ -23,31 +24,40 @@ def assert_result(result, expected_result):
|
||||
assert np.all(result == expected_result)
|
||||
|
||||
|
||||
def run(engine, args, compilation_result, keyset_cache):
|
||||
def run(engine, args, compilation_result, keyset_cache, circuit_name="main"):
|
||||
"""Execute engine on the given arguments.
|
||||
|
||||
Perform required loading, encryption, execution, and decryption."""
|
||||
# Dev
|
||||
compilation_feedback = engine.load_compilation_feedback(compilation_result)
|
||||
assert isinstance(compilation_feedback, CompilationFeedback)
|
||||
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
|
||||
assert isinstance(compilation_feedback.complexity, float)
|
||||
assert isinstance(compilation_feedback.p_error, float)
|
||||
assert isinstance(compilation_feedback.global_p_error, float)
|
||||
assert isinstance(compilation_feedback.total_secret_keys_size, int)
|
||||
assert isinstance(compilation_feedback.total_bootstrap_keys_size, int)
|
||||
assert isinstance(compilation_feedback.total_inputs_size, int)
|
||||
assert isinstance(compilation_feedback.total_output_size, int)
|
||||
assert isinstance(compilation_feedback.circuit_feedbacks, list)
|
||||
circuit_feedback = next(
|
||||
filter(lambda x: x.name == circuit_name, compilation_feedback.circuit_feedbacks)
|
||||
)
|
||||
assert isinstance(circuit_feedback, CircuitCompilationFeedback)
|
||||
assert isinstance(circuit_feedback.total_inputs_size, int)
|
||||
assert isinstance(circuit_feedback.total_output_size, int)
|
||||
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, args, circuit_name
|
||||
)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result, False)
|
||||
server_lambda = engine.load_server_lambda(compilation_result, False, circuit_name)
|
||||
evaluation_keys = key_set.get_evaluation_keys()
|
||||
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
|
||||
# Client
|
||||
result = ClientSupport.decrypt_result(client_parameters, key_set, public_result)
|
||||
result = ClientSupport.decrypt_result(
|
||||
client_parameters, key_set, public_result, circuit_name
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@@ -57,11 +67,12 @@ def compile_run_assert(
|
||||
args,
|
||||
expected_result,
|
||||
keyset_cache,
|
||||
options=CompilationOptions.new("main"),
|
||||
options=CompilationOptions.new(),
|
||||
circuit_name="main",
|
||||
):
|
||||
"""Compile run and assert result."""
|
||||
compilation_result = engine.compile(mlir_input, options)
|
||||
result = run(engine, args, compilation_result, keyset_cache)
|
||||
result = run(engine, args, compilation_result, keyset_cache, circuit_name)
|
||||
assert_result(result, expected_result)
|
||||
|
||||
|
||||
@@ -231,6 +242,33 @@ def test_lib_compilation_artifacts():
|
||||
assert not os.path.exists(engine.get_shared_lib_path())
|
||||
|
||||
|
||||
def test_multi_circuits(keyset_cache):
|
||||
from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy
|
||||
|
||||
mlir_str = """
|
||||
func.func @add(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
func.func @sub(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
"""
|
||||
args = (10, 3)
|
||||
expected_add_result = 13
|
||||
expected_sub_result = 7
|
||||
engine = LibrarySupport.new("./py_test_multi_circuits")
|
||||
options = CompilationOptions.new()
|
||||
options.set_optimizer_strategy(OptimizerStrategy.V0)
|
||||
compile_run_assert(
|
||||
engine, mlir_str, args, expected_add_result, keyset_cache, options, "add"
|
||||
)
|
||||
compile_run_assert(
|
||||
engine, mlir_str, args, expected_sub_result, keyset_cache, options, "sub"
|
||||
)
|
||||
|
||||
|
||||
def _test_lib_compile_and_run_with_options(keyset_cache, options):
|
||||
mlir_input = """
|
||||
func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
@@ -247,21 +285,21 @@ def _test_lib_compile_and_run_with_options(keyset_cache, options):
|
||||
|
||||
|
||||
def test_lib_compile_and_run_p_error(keyset_cache):
|
||||
options = CompilationOptions.new("main")
|
||||
options = CompilationOptions.new()
|
||||
options.set_p_error(0.00001)
|
||||
options.set_display_optimizer_choice(True)
|
||||
_test_lib_compile_and_run_with_options(keyset_cache, options)
|
||||
|
||||
|
||||
def test_lib_compile_and_run_global_p_error(keyset_cache):
|
||||
options = CompilationOptions.new("main")
|
||||
options = CompilationOptions.new()
|
||||
options.set_global_p_error(0.00001)
|
||||
options.set_display_optimizer_choice(True)
|
||||
_test_lib_compile_and_run_with_options(keyset_cache, options)
|
||||
|
||||
|
||||
def test_lib_compile_and_run_security_level(keyset_cache):
|
||||
options = CompilationOptions.new("main")
|
||||
options = CompilationOptions.new()
|
||||
options.set_security_level(80)
|
||||
options.set_display_optimizer_choice(True)
|
||||
_test_lib_compile_and_run_with_options(keyset_cache, options)
|
||||
@@ -276,7 +314,7 @@ def test_compile_and_run_auto_parallelize(
|
||||
):
|
||||
artifact_dir = "./py_test_compile_and_run_auto_parallelize"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
options = CompilationOptions.new("main")
|
||||
options = CompilationOptions.new()
|
||||
options.set_auto_parallelize(True)
|
||||
compile_run_assert(
|
||||
engine, mlir_input, args, expected_result, keyset_cache, options=options
|
||||
@@ -301,7 +339,7 @@ def test_compile_and_run_auto_parallelize(
|
||||
# if no_parallel:
|
||||
# artifact_dir = "./py_test_compile_dataflow_and_fail_run"
|
||||
# engine = LibrarySupport.new(artifact_dir)
|
||||
# options = CompilationOptions.new("main")
|
||||
# options = CompilationOptions.new()
|
||||
# options.set_auto_parallelize(True)
|
||||
# with pytest.raises(
|
||||
# RuntimeError,
|
||||
@@ -334,7 +372,7 @@ def test_compile_and_run_loop_parallelize(
|
||||
):
|
||||
artifact_dir = "./py_test_compile_and_run_loop_parallelize"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
options = CompilationOptions.new("main")
|
||||
options = CompilationOptions.new()
|
||||
options.set_loop_parallelize(True)
|
||||
compile_run_assert(
|
||||
engine, mlir_input, args, expected_result, keyset_cache, options=options
|
||||
@@ -365,29 +403,6 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache):
|
||||
compile_run_assert(engine, mlir_input, args, None, keyset_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input",
|
||||
[
|
||||
pytest.param(
|
||||
"""
|
||||
func.func @test(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
|
||||
{
|
||||
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
(tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7>
|
||||
return %ret : !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
id="not @main",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_invalid(mlir_input):
|
||||
artifact_dir = "./py_test_compile_invalid"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
with pytest.raises(RuntimeError, match=r"Function not found, name='main'"):
|
||||
engine.compile(mlir_input)
|
||||
|
||||
|
||||
def test_crt_decomposition_feedback():
|
||||
mlir = """
|
||||
|
||||
@@ -401,11 +416,27 @@ func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
|
||||
artifact_dir = "./py_test_crt_decomposition_feedback"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
compilation_result = engine.compile(mlir, options=CompilationOptions.new("main"))
|
||||
compilation_result = engine.compile(mlir, options=CompilationOptions.new())
|
||||
compilation_feedback = engine.load_compilation_feedback(compilation_result)
|
||||
|
||||
assert isinstance(compilation_feedback, CompilationFeedback)
|
||||
assert compilation_feedback.crt_decompositions_of_outputs == [[7, 8, 9, 11, 13]]
|
||||
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
|
||||
assert isinstance(compilation_feedback.complexity, float)
|
||||
assert isinstance(compilation_feedback.p_error, float)
|
||||
assert isinstance(compilation_feedback.global_p_error, float)
|
||||
assert isinstance(compilation_feedback.total_secret_keys_size, int)
|
||||
assert isinstance(compilation_feedback.total_bootstrap_keys_size, int)
|
||||
assert isinstance(compilation_feedback.circuit_feedbacks, list)
|
||||
assert isinstance(
|
||||
compilation_feedback.circuit_feedbacks[0], CircuitCompilationFeedback
|
||||
)
|
||||
assert isinstance(compilation_feedback.circuit_feedbacks[0].total_inputs_size, int)
|
||||
assert isinstance(compilation_feedback.circuit_feedbacks[0].total_output_size, int)
|
||||
assert isinstance(
|
||||
compilation_feedback.circuit_feedbacks[0].crt_decompositions_of_outputs, list
|
||||
)
|
||||
assert compilation_feedback.circuit_feedbacks[0].crt_decompositions_of_outputs == [
|
||||
[7, 8, 9, 11, 13]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -450,10 +481,11 @@ def test_memory_usage(mlir: str, expected_memory_usage_per_loc: dict):
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
compilation_result = engine.compile(mlir)
|
||||
compilation_feedback = engine.load_compilation_feedback(compilation_result)
|
||||
assert isinstance(compilation_feedback, CompilationFeedback)
|
||||
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
|
||||
|
||||
assert (
|
||||
expected_memory_usage_per_loc == compilation_feedback.memory_usage_per_location
|
||||
expected_memory_usage_per_loc
|
||||
== compilation_feedback.circuit_feedbacks[0].memory_usage_per_location
|
||||
)
|
||||
|
||||
shutil.rmtree(artifact_dir)
|
||||
|
||||
@@ -49,7 +49,7 @@ def compile_run_assert(
|
||||
mlir_input,
|
||||
args_and_shape,
|
||||
expected_result,
|
||||
options=CompilationOptions.new("main"),
|
||||
options=CompilationOptions.new(),
|
||||
):
|
||||
# compile with simulation
|
||||
options.simulation(True)
|
||||
|
||||
@@ -32,7 +32,10 @@ module {
|
||||
compilation_result = support.compile(mlir)
|
||||
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
compilation_feedback = support.load_compilation_feedback(compilation_result)
|
||||
program_compilation_feedback = support.load_compilation_feedback(
|
||||
compilation_result
|
||||
)
|
||||
compilation_feedback = program_compilation_feedback.circuit("main")
|
||||
|
||||
pbs_count = compilation_feedback.count(
|
||||
operations={
|
||||
|
||||
Reference in New Issue
Block a user