feat(compiler): introduce concrete-protocol

This commit:
 + Adds support for a protocol which enables interoperability between
   Concrete, TFHE-rs, and potentially other contributors to the FHE
   ecosystem.
 + Gets rid of hand-made serialization in the compiler, and
   client/server libs.
 + Refactors client/server libs to allow more pre/post processing of
   circuit inputs/outputs.

The protocol is supported by a definition in the shape of a capnp file,
which defines different types of objects among which:
 + ProgramInfo object, which is a precise description of a set of FHE
   circuits coming from the same compilation (i.e. function type
   information), and the associated key set.
 + *Key objects, which represent secret/public keys used to
   encrypt/execute FHE circuits.
 + Value object, which represents values that can be transferred between
   client and server to support calls to FHE circuits.

The hand-rolled serialization that was previously used is completely
dropped in favor of capnp in the whole codebase.

The client/server libs are refactored to introduce a modular design for
pre/post processing. Reading the ProgramInfo file associated with a
compilation, the client and server libs assemble a pipeline of
transformers (functions) for pre and post processing of values coming in
and out of a circuit. This design properly decouples various aspects of
the processing, and allows these capabilities to be safely extended.

In practice this commit includes the following:
 + Defines the specification in a concreteprotocol package
 + Integrates the compilation of this package as a compiler dependency
   via cmake
 + Modifies the compiler to use the Encodings objects defined in the
   protocol
 + Modifies the compiler to emit ProgramInfo files as compilation
   artifacts, and gets rid of the bloated ClientParameters.
 + Introduces a new Common library containing the functionalities shared
   between the compiler and the client/server libs.
 + Introduces a functional pre-post processing pipeline to this common
   library
 + Modifies the client/server libs to support loading ProgramInfo
   objects, and calling circuits using Value messages.
 + Drops support of JIT.
 + Drops support of the C API.
 + Drops support of Rust bindings.

Co-authored-by: Nikita Frolov <nf@mkmks.org>
This commit is contained in:
Alexandre Péré
2023-10-27 16:32:40 +02:00
committed by Alexandre Péré
parent 9139101cc3
commit e8ef48ffd8
207 changed files with 8601 additions and 16816 deletions

View File

@@ -36,7 +36,7 @@ def test_accepted_ints(value):
except Exception:
pytest.fail(f"value of type {type(value)} should be supported")
assert arg.is_scalar(), "should have been a scalar"
assert arg.get_scalar() == value
assert arg.get_signed_scalar() == value
# TODO: #495
@@ -60,8 +60,8 @@ def test_accepted_ndarray(dtype, maxvalue):
assert np.all(np.equal(arg.get_tensor_shape(), value.shape))
assert np.all(
np.equal(
value,
np.array(arg.get_tensor_data()).reshape(arg.get_tensor_shape()),
value.astype(np.int64),
np.array(arg.get_signed_tensor_data()).reshape(arg.get_tensor_shape()),
)
)
@@ -73,4 +73,4 @@ def test_accepted_array_as_scalar():
except Exception:
pytest.fail(f"value of type {type(value)} should be supported")
assert arg.is_scalar(), "should have been a scalar"
assert arg.get_scalar() == value
assert arg.get_signed_scalar() == value

View File

@@ -52,7 +52,7 @@ func.func @main(%arg0: tensor<4x!FHE.eint<5>>, %arg1: tensor<4xi6>) -> !FHE.eint
""",
(
np.array([1, 2, 3, 4], dtype=np.uint8),
np.array([1, 2, 3, 4], dtype=np.uint64),
np.array([4, 3, 2, 1], dtype=np.uint8),
),
20,
@@ -69,8 +69,8 @@ func.func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> ten
""",
(
np.array([1, 2, 3, 4], dtype=np.uint8),
np.array([7, 0, 1, 5], dtype=np.uint8),
np.array([1, 2, 3, 4], dtype=np.uint64),
np.array([7, 0, 1, 5], dtype=np.uint64),
),
np.array([8, 2, 4, 9]),
id="enc_enc_ndarray_args",
@@ -81,7 +81,7 @@ def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache):
with tempfile.TemporaryDirectory() as tmpdirname:
support = LibrarySupport.new(str(tmpdirname))
compilation_result = support.compile(mlir)
server_lambda = support.load_server_lambda(compilation_result)
server_lambda = support.load_server_lambda(compilation_result, False)
client_parameters = support.load_client_parameters(compilation_result)
keyset = ClientSupport.key_set(client_parameters, keyset_cache)

View File

@@ -4,7 +4,6 @@ import os.path
import shutil
import numpy as np
from concrete.compiler import (
JITSupport,
LibrarySupport,
ClientSupport,
CompilationOptions,
@@ -44,7 +43,7 @@ def run(engine, args, compilation_result, keyset_cache):
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
server_lambda = engine.load_server_lambda(compilation_result, False)
evaluation_keys = key_set.get_evaluation_keys()
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
# Client
@@ -60,10 +59,7 @@ def compile_run_assert(
keyset_cache,
options=CompilationOptions.new("main"),
):
"""Compile run and assert result.
Can take both JITSupport or LibrarySupport as engine.
"""
"""Compile run and assert result."""
compilation_result = engine.compile(mlir_input, options)
result = run(engine, args, compilation_result, keyset_cache)
assert_result(result, expected_result)
@@ -88,7 +84,7 @@ end_to_end_fixture = [
return %1: !FHE.eint<7>
}
""",
(np.array(4, dtype=np.uint8), np.array(5, dtype=np.uint8)),
(np.array(4, dtype=np.int64), np.array(5, dtype=np.uint8)),
9,
id="add_eint_int_with_ndarray_as_scalar",
),
@@ -197,12 +193,6 @@ end_to_end_parallel_fixture = [
]
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_jit_compile_and_run(mlir_input, args, expected_result, keyset_cache):
engine = JITSupport.new()
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_and_run(mlir_input, args, expected_result, keyset_cache):
artifact_dir = "./py_test_lib_compile_and_run"
@@ -234,10 +224,10 @@ def test_lib_compilation_artifacts():
artifact_dir = "./test_artifacts"
engine = LibrarySupport.new(artifact_dir)
engine.compile(mlir_str)
assert os.path.exists(engine.get_client_parameters_path())
assert os.path.exists(engine.get_program_info_path())
assert os.path.exists(engine.get_shared_lib_path())
shutil.rmtree(artifact_dir)
assert not os.path.exists(engine.get_client_parameters_path())
assert not os.path.exists(engine.get_program_info_path())
assert not os.path.exists(engine.get_shared_lib_path())
@@ -281,17 +271,11 @@ def test_lib_compile_and_run_security_level(keyset_cache):
@pytest.mark.parametrize(
"mlir_input, args, expected_result", end_to_end_parallel_fixture
)
@pytest.mark.parametrize(
"EngineClass",
[
pytest.param(JITSupport, id="JIT"),
pytest.param(LibrarySupport, id="Library"),
],
)
def test_compile_and_run_auto_parallelize(
mlir_input, args, expected_result, keyset_cache, EngineClass
mlir_input, args, expected_result, keyset_cache
):
engine = EngineClass.new()
artifact_dir = "./py_test_compile_and_run_auto_parallelize"
engine = LibrarySupport.new(artifact_dir)
options = CompilationOptions.new("main")
options.set_auto_parallelize(True)
compile_run_assert(
@@ -299,28 +283,33 @@ def test_compile_and_run_auto_parallelize(
)
# FIXME #51
@pytest.mark.xfail(
platform.system() == "Darwin",
reason="MacOS have issues with translating Cpp exceptions",
)
@pytest.mark.parametrize(
"mlir_input, args, expected_result", end_to_end_parallel_fixture
)
def test_compile_dataflow_and_fail_run(
mlir_input, args, expected_result, keyset_cache, no_parallel
):
if no_parallel:
engine = JITSupport.new()
options = CompilationOptions.new("main")
options.set_auto_parallelize(True)
with pytest.raises(
RuntimeError,
match="call: current runtime doesn't support dataflow execution",
):
compile_run_assert(
engine, mlir_input, args, expected_result, keyset_cache, options=options
)
# This test was running in JIT mode at first. Problem is now, it does not work with the library
# support. It is not clear to me why, but the dataflow runtime seems to have stuffs dedicated to
# the dropped JIT support... I am cancelling it until further explored.
#
# # FIXME #51
# @pytest.mark.xfail(
# platform.system() == "Darwin",
# reason="MacOS have issues with translating Cpp exceptions",
# )
# @pytest.mark.parametrize(
# "mlir_input, args, expected_result", end_to_end_parallel_fixture
# )
# def test_compile_dataflow_and_fail_run(
# mlir_input, args, expected_result, keyset_cache, no_parallel
# ):
# if no_parallel:
# artifact_dir = "./py_test_compile_dataflow_and_fail_run"
# engine = LibrarySupport.new(artifact_dir)
# options = CompilationOptions.new("main")
# options.set_auto_parallelize(True)
# with pytest.raises(
# RuntimeError,
# match="call: current runtime doesn't support dataflow execution",
# ):
# compile_run_assert(
# engine, mlir_input, args, expected_result, keyset_cache, options=options
# )
@pytest.mark.parametrize(
@@ -340,17 +329,11 @@ def test_compile_dataflow_and_fail_run(
),
],
)
@pytest.mark.parametrize(
"EngineClass",
[
pytest.param(JITSupport, id="JIT"),
pytest.param(LibrarySupport, id="Library"),
],
)
def test_compile_and_run_loop_parallelize(
mlir_input, args, expected_result, keyset_cache, EngineClass
mlir_input, args, expected_result, keyset_cache
):
engine = EngineClass.new()
artifact_dir = "./py_test_compile_and_run_loop_parallelize"
engine = LibrarySupport.new(artifact_dir)
options = CompilationOptions.new("main")
options.set_loop_parallelize(True)
compile_run_assert(
@@ -378,17 +361,9 @@ def test_compile_and_run_loop_parallelize(
),
],
)
@pytest.mark.parametrize(
"EngineClass",
[
pytest.param(JITSupport, id="JIT"),
pytest.param(LibrarySupport, id="Library"),
],
)
def test_compile_and_run_invalid_arg_number(
mlir_input, args, EngineClass, keyset_cache
):
engine = EngineClass.new()
def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache):
artifact_dir = "./py_test_compile_and_run_invalid_arg_number"
engine = LibrarySupport.new(artifact_dir)
with pytest.raises(
RuntimeError, match=r"function has arity 2 but is applied to too many arguments"
):
@@ -417,7 +392,8 @@ def test_compile_and_run_invalid_arg_number(
],
)
def test_compile_invalid(mlir_input):
engine = JITSupport.new()
artifact_dir = "./py_test_compile_invalid"
engine = LibrarySupport.new(artifact_dir)
with pytest.raises(RuntimeError, match=r"Function not found, name='main'"):
engine.compile(mlir_input)
@@ -433,7 +409,8 @@ func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
"""
engine = JITSupport.new()
artifact_dir = "./py_test_crt_decomposition_feedback"
engine = LibrarySupport.new(artifact_dir)
compilation_result = engine.compile(mlir, options=CompilationOptions.new("main"))
compilation_feedback = engine.load_compilation_feedback(compilation_result)

View File

@@ -30,7 +30,7 @@ module {
support = LibrarySupport.new(str(tmpdirname))
compilation_result = support.compile(mlir)
server_lambda = support.load_server_lambda(compilation_result)
server_lambda = support.load_server_lambda(compilation_result, False)
client_parameters = support.load_client_parameters(compilation_result)
keyset = ClientSupport.key_set(client_parameters)

View File

@@ -37,7 +37,7 @@ def run_simulated(engine, args_and_shape, compilation_result):
values.append(sim_value_exporter.export_tensor(pos, arg, shape))
pos += 1
public_arguments = PublicArguments.new(client_parameters, values)
server_lambda = engine.load_server_lambda(compilation_result)
server_lambda = engine.load_server_lambda(compilation_result, True)
public_result = engine.simulate(server_lambda, public_arguments)
sim_value_decrypter = SimulatedValueDecrypter.new(client_parameters)
result = sim_value_decrypter.decrypt(0, public_result.get_value(0))

View File

@@ -3,9 +3,6 @@ from concrete.compiler import (
ClientParameters,
ClientSupport,
CompilationOptions,
JITCompilationResult,
JITLambda,
JITSupport,
KeySetCache,
KeySet,
LambdaArgument,
@@ -24,9 +21,6 @@ from concrete.compiler import (
pytest.param(ClientParameters, id="ClientParameters"),
pytest.param(ClientSupport, id="ClientSupport"),
pytest.param(CompilationOptions, id="CompilationOptions"),
pytest.param(JITCompilationResult, id="JITCompilationResult"),
pytest.param(JITLambda, id="JITLambda"),
pytest.param(JITSupport, id="JITSupport"),
pytest.param(KeySetCache, id="KeySetCache"),
pytest.param(KeySet, id="KeySet"),
pytest.param(LambdaArgument, id="LambdaArgument"),