feat(compiler): support multi-circuit compilation

This commit is contained in:
Alexandre Péré
2024-02-09 11:39:06 +01:00
committed by Alexandre Péré
parent 3247a28d9d
commit 9b5a2e46da
78 changed files with 1200 additions and 865 deletions

View File

@@ -7,7 +7,8 @@ from concrete.compiler import (
LibrarySupport,
ClientSupport,
CompilationOptions,
CompilationFeedback,
ProgramCompilationFeedback,
CircuitCompilationFeedback,
)
@@ -23,31 +24,40 @@ def assert_result(result, expected_result):
assert np.all(result == expected_result)
def run(engine, args, compilation_result, keyset_cache):
def run(engine, args, compilation_result, keyset_cache, circuit_name="main"):
"""Execute engine on the given arguments.
Perform required loading, encryption, execution, and decryption."""
# Dev
compilation_feedback = engine.load_compilation_feedback(compilation_result)
assert isinstance(compilation_feedback, CompilationFeedback)
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
assert isinstance(compilation_feedback.complexity, float)
assert isinstance(compilation_feedback.p_error, float)
assert isinstance(compilation_feedback.global_p_error, float)
assert isinstance(compilation_feedback.total_secret_keys_size, int)
assert isinstance(compilation_feedback.total_bootstrap_keys_size, int)
assert isinstance(compilation_feedback.total_inputs_size, int)
assert isinstance(compilation_feedback.total_output_size, int)
assert isinstance(compilation_feedback.circuit_feedbacks, list)
circuit_feedback = next(
filter(lambda x: x.name == circuit_name, compilation_feedback.circuit_feedbacks)
)
assert isinstance(circuit_feedback, CircuitCompilationFeedback)
assert isinstance(circuit_feedback.total_inputs_size, int)
assert isinstance(circuit_feedback.total_output_size, int)
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
public_arguments = ClientSupport.encrypt_arguments(
client_parameters, key_set, args, circuit_name
)
# Server
server_lambda = engine.load_server_lambda(compilation_result, False)
server_lambda = engine.load_server_lambda(compilation_result, False, circuit_name)
evaluation_keys = key_set.get_evaluation_keys()
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
# Client
result = ClientSupport.decrypt_result(client_parameters, key_set, public_result)
result = ClientSupport.decrypt_result(
client_parameters, key_set, public_result, circuit_name
)
return result
@@ -57,11 +67,12 @@ def compile_run_assert(
args,
expected_result,
keyset_cache,
options=CompilationOptions.new("main"),
options=CompilationOptions.new(),
circuit_name="main",
):
"""Compile run and assert result."""
compilation_result = engine.compile(mlir_input, options)
result = run(engine, args, compilation_result, keyset_cache)
result = run(engine, args, compilation_result, keyset_cache, circuit_name)
assert_result(result, expected_result)
@@ -231,6 +242,33 @@ def test_lib_compilation_artifacts():
assert not os.path.exists(engine.get_shared_lib_path())
def test_multi_circuits(keyset_cache):
from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy
mlir_str = """
func.func @add(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
func.func @sub(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
"""
args = (10, 3)
expected_add_result = 13
expected_sub_result = 7
engine = LibrarySupport.new("./py_test_multi_circuits")
options = CompilationOptions.new()
options.set_optimizer_strategy(OptimizerStrategy.V0)
compile_run_assert(
engine, mlir_str, args, expected_add_result, keyset_cache, options, "add"
)
compile_run_assert(
engine, mlir_str, args, expected_sub_result, keyset_cache, options, "sub"
)
def _test_lib_compile_and_run_with_options(keyset_cache, options):
mlir_input = """
func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
@@ -247,21 +285,21 @@ def _test_lib_compile_and_run_with_options(keyset_cache, options):
def test_lib_compile_and_run_p_error(keyset_cache):
options = CompilationOptions.new("main")
options = CompilationOptions.new()
options.set_p_error(0.00001)
options.set_display_optimizer_choice(True)
_test_lib_compile_and_run_with_options(keyset_cache, options)
def test_lib_compile_and_run_global_p_error(keyset_cache):
options = CompilationOptions.new("main")
options = CompilationOptions.new()
options.set_global_p_error(0.00001)
options.set_display_optimizer_choice(True)
_test_lib_compile_and_run_with_options(keyset_cache, options)
def test_lib_compile_and_run_security_level(keyset_cache):
options = CompilationOptions.new("main")
options = CompilationOptions.new()
options.set_security_level(80)
options.set_display_optimizer_choice(True)
_test_lib_compile_and_run_with_options(keyset_cache, options)
@@ -276,7 +314,7 @@ def test_compile_and_run_auto_parallelize(
):
artifact_dir = "./py_test_compile_and_run_auto_parallelize"
engine = LibrarySupport.new(artifact_dir)
options = CompilationOptions.new("main")
options = CompilationOptions.new()
options.set_auto_parallelize(True)
compile_run_assert(
engine, mlir_input, args, expected_result, keyset_cache, options=options
@@ -301,7 +339,7 @@ def test_compile_and_run_auto_parallelize(
# if no_parallel:
# artifact_dir = "./py_test_compile_dataflow_and_fail_run"
# engine = LibrarySupport.new(artifact_dir)
# options = CompilationOptions.new("main")
# options = CompilationOptions.new()
# options.set_auto_parallelize(True)
# with pytest.raises(
# RuntimeError,
@@ -334,7 +372,7 @@ def test_compile_and_run_loop_parallelize(
):
artifact_dir = "./py_test_compile_and_run_loop_parallelize"
engine = LibrarySupport.new(artifact_dir)
options = CompilationOptions.new("main")
options = CompilationOptions.new()
options.set_loop_parallelize(True)
compile_run_assert(
engine, mlir_input, args, expected_result, keyset_cache, options=options
@@ -365,29 +403,6 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache):
compile_run_assert(engine, mlir_input, args, None, keyset_cache)
@pytest.mark.parametrize(
"mlir_input",
[
pytest.param(
"""
func.func @test(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
{
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
(tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7>
return %ret : !FHE.eint<7>
}
""",
id="not @main",
),
],
)
def test_compile_invalid(mlir_input):
artifact_dir = "./py_test_compile_invalid"
engine = LibrarySupport.new(artifact_dir)
with pytest.raises(RuntimeError, match=r"Function not found, name='main'"):
engine.compile(mlir_input)
def test_crt_decomposition_feedback():
mlir = """
@@ -401,11 +416,27 @@ func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
artifact_dir = "./py_test_crt_decomposition_feedback"
engine = LibrarySupport.new(artifact_dir)
compilation_result = engine.compile(mlir, options=CompilationOptions.new("main"))
compilation_result = engine.compile(mlir, options=CompilationOptions.new())
compilation_feedback = engine.load_compilation_feedback(compilation_result)
assert isinstance(compilation_feedback, CompilationFeedback)
assert compilation_feedback.crt_decompositions_of_outputs == [[7, 8, 9, 11, 13]]
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
assert isinstance(compilation_feedback.complexity, float)
assert isinstance(compilation_feedback.p_error, float)
assert isinstance(compilation_feedback.global_p_error, float)
assert isinstance(compilation_feedback.total_secret_keys_size, int)
assert isinstance(compilation_feedback.total_bootstrap_keys_size, int)
assert isinstance(compilation_feedback.circuit_feedbacks, list)
assert isinstance(
compilation_feedback.circuit_feedbacks[0], CircuitCompilationFeedback
)
assert isinstance(compilation_feedback.circuit_feedbacks[0].total_inputs_size, int)
assert isinstance(compilation_feedback.circuit_feedbacks[0].total_output_size, int)
assert isinstance(
compilation_feedback.circuit_feedbacks[0].crt_decompositions_of_outputs, list
)
assert compilation_feedback.circuit_feedbacks[0].crt_decompositions_of_outputs == [
[7, 8, 9, 11, 13]
]
@pytest.mark.parametrize(
@@ -450,10 +481,11 @@ def test_memory_usage(mlir: str, expected_memory_usage_per_loc: dict):
engine = LibrarySupport.new(artifact_dir)
compilation_result = engine.compile(mlir)
compilation_feedback = engine.load_compilation_feedback(compilation_result)
assert isinstance(compilation_feedback, CompilationFeedback)
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
assert (
expected_memory_usage_per_loc == compilation_feedback.memory_usage_per_location
expected_memory_usage_per_loc
== compilation_feedback.circuit_feedbacks[0].memory_usage_per_location
)
shutil.rmtree(artifact_dir)

View File

@@ -49,7 +49,7 @@ def compile_run_assert(
mlir_input,
args_and_shape,
expected_result,
options=CompilationOptions.new("main"),
options=CompilationOptions.new(),
):
# compile with simulation
options.simulation(True)

View File

@@ -32,7 +32,10 @@ module {
compilation_result = support.compile(mlir)
client_parameters = support.load_client_parameters(compilation_result)
compilation_feedback = support.load_compilation_feedback(compilation_result)
program_compilation_feedback = support.load_compilation_feedback(
compilation_result
)
compilation_feedback = program_compilation_feedback.circuit("main")
pbs_count = compilation_feedback.count(
operations={