mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-18 08:31:31 -05:00
feat(compiler): introduce concrete-protocol
This commit: + Adds support for a protocol which enables inter-op between concrete, tfhe-rs and potentially other contributors to the fhe ecosystem. + Gets rid of hand-made serialization in the compiler, and client/server libs. + Refactors client/server libs to allow more pre/post processing of circuit inputs/outputs. The protocol is supported by a definition in the shape of a capnp file, which defines different types of objects among which: + ProgramInfo object, which is a precise description of a set of fhe circuit coming from the same compilation (understand function type information), and the associated key set. + *Key objects, which represent secret/public keys used to encrypt/execute fhe circuits. + Value object, which represent values that can be transferred between client and server to support calls to fhe circuits. The hand-rolled serialization that was previously used is completely dropped in favor of capnp in the whole codebase. The client/server libs, are refactored to introduce a modular design for pre-post processing. Reading the ProgramInfo file associated with a compilation, the client and server libs assemble a pipeline of transformers (functions) for pre and post processing of values coming in and out of a circuit. This design properly decouples various aspects of the processing, and allows these capabilities to be safely extended. In practice this commit includes the following: + Defines the specification in a concreteprotocol package + Integrate the compilation of this package as a compiler dependency via cmake + Modify the compiler to use the Encodings objects defined in the protocol + Modify the compiler to emit ProgramInfo files as compilation artifact, and gets rid of the bloated ClientParameters. + Introduces a new Common library containing the functionalities shared between the compiler and the client/server libs. + Introduces a functional pre-post processing pipeline to this common library + Modify the client/server libs to support loading ProgramInfo objects, and calling circuits using Value messages. + Drops support of JIT. + Drops support of C-api. + Drops support of Rust bindings. Co-authored-by: Nikita Frolov <nf@mkmks.org>
This commit is contained in:
committed by
Alexandre Péré
parent
9139101cc3
commit
e8ef48ffd8
@@ -36,7 +36,7 @@ def test_accepted_ints(value):
|
||||
except Exception:
|
||||
pytest.fail(f"value of type {type(value)} should be supported")
|
||||
assert arg.is_scalar(), "should have been a scalar"
|
||||
assert arg.get_scalar() == value
|
||||
assert arg.get_signed_scalar() == value
|
||||
|
||||
|
||||
# TODO: #495
|
||||
@@ -60,8 +60,8 @@ def test_accepted_ndarray(dtype, maxvalue):
|
||||
assert np.all(np.equal(arg.get_tensor_shape(), value.shape))
|
||||
assert np.all(
|
||||
np.equal(
|
||||
value,
|
||||
np.array(arg.get_tensor_data()).reshape(arg.get_tensor_shape()),
|
||||
value.astype(np.int64),
|
||||
np.array(arg.get_signed_tensor_data()).reshape(arg.get_tensor_shape()),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -73,4 +73,4 @@ def test_accepted_array_as_scalar():
|
||||
except Exception:
|
||||
pytest.fail(f"value of type {type(value)} should be supported")
|
||||
assert arg.is_scalar(), "should have been a scalar"
|
||||
assert arg.get_scalar() == value
|
||||
assert arg.get_signed_scalar() == value
|
||||
|
||||
@@ -52,7 +52,7 @@ func.func @main(%arg0: tensor<4x!FHE.eint<5>>, %arg1: tensor<4xi6>) -> !FHE.eint
|
||||
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint8),
|
||||
np.array([1, 2, 3, 4], dtype=np.uint64),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint8),
|
||||
),
|
||||
20,
|
||||
@@ -69,8 +69,8 @@ func.func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> ten
|
||||
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint8),
|
||||
np.array([7, 0, 1, 5], dtype=np.uint8),
|
||||
np.array([1, 2, 3, 4], dtype=np.uint64),
|
||||
np.array([7, 0, 1, 5], dtype=np.uint64),
|
||||
),
|
||||
np.array([8, 2, 4, 9]),
|
||||
id="enc_enc_ndarray_args",
|
||||
@@ -81,7 +81,7 @@ def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
support = LibrarySupport.new(str(tmpdirname))
|
||||
compilation_result = support.compile(mlir)
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
server_lambda = support.load_server_lambda(compilation_result, False)
|
||||
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
keyset = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
|
||||
@@ -4,7 +4,6 @@ import os.path
|
||||
import shutil
|
||||
import numpy as np
|
||||
from concrete.compiler import (
|
||||
JITSupport,
|
||||
LibrarySupport,
|
||||
ClientSupport,
|
||||
CompilationOptions,
|
||||
@@ -44,7 +43,7 @@ def run(engine, args, compilation_result, keyset_cache):
|
||||
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
server_lambda = engine.load_server_lambda(compilation_result, False)
|
||||
evaluation_keys = key_set.get_evaluation_keys()
|
||||
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
|
||||
# Client
|
||||
@@ -60,10 +59,7 @@ def compile_run_assert(
|
||||
keyset_cache,
|
||||
options=CompilationOptions.new("main"),
|
||||
):
|
||||
"""Compile run and assert result.
|
||||
|
||||
Can take both JITSupport or LibrarySupport as engine.
|
||||
"""
|
||||
"""Compile run and assert result."""
|
||||
compilation_result = engine.compile(mlir_input, options)
|
||||
result = run(engine, args, compilation_result, keyset_cache)
|
||||
assert_result(result, expected_result)
|
||||
@@ -88,7 +84,7 @@ end_to_end_fixture = [
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(np.array(4, dtype=np.uint8), np.array(5, dtype=np.uint8)),
|
||||
(np.array(4, dtype=np.int64), np.array(5, dtype=np.uint8)),
|
||||
9,
|
||||
id="add_eint_int_with_ndarray_as_scalar",
|
||||
),
|
||||
@@ -197,12 +193,6 @@ end_to_end_parallel_fixture = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_jit_compile_and_run(mlir_input, args, expected_result, keyset_cache):
|
||||
engine = JITSupport.new()
|
||||
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_lib_compile_and_run(mlir_input, args, expected_result, keyset_cache):
|
||||
artifact_dir = "./py_test_lib_compile_and_run"
|
||||
@@ -234,10 +224,10 @@ def test_lib_compilation_artifacts():
|
||||
artifact_dir = "./test_artifacts"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
engine.compile(mlir_str)
|
||||
assert os.path.exists(engine.get_client_parameters_path())
|
||||
assert os.path.exists(engine.get_program_info_path())
|
||||
assert os.path.exists(engine.get_shared_lib_path())
|
||||
shutil.rmtree(artifact_dir)
|
||||
assert not os.path.exists(engine.get_client_parameters_path())
|
||||
assert not os.path.exists(engine.get_program_info_path())
|
||||
assert not os.path.exists(engine.get_shared_lib_path())
|
||||
|
||||
|
||||
@@ -281,17 +271,11 @@ def test_lib_compile_and_run_security_level(keyset_cache):
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result", end_to_end_parallel_fixture
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"EngineClass",
|
||||
[
|
||||
pytest.param(JITSupport, id="JIT"),
|
||||
pytest.param(LibrarySupport, id="Library"),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_auto_parallelize(
|
||||
mlir_input, args, expected_result, keyset_cache, EngineClass
|
||||
mlir_input, args, expected_result, keyset_cache
|
||||
):
|
||||
engine = EngineClass.new()
|
||||
artifact_dir = "./py_test_compile_and_run_auto_parallelize"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
options = CompilationOptions.new("main")
|
||||
options.set_auto_parallelize(True)
|
||||
compile_run_assert(
|
||||
@@ -299,28 +283,33 @@ def test_compile_and_run_auto_parallelize(
|
||||
)
|
||||
|
||||
|
||||
# FIXME #51
|
||||
@pytest.mark.xfail(
|
||||
platform.system() == "Darwin",
|
||||
reason="MacOS have issues with translating Cpp exceptions",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result", end_to_end_parallel_fixture
|
||||
)
|
||||
def test_compile_dataflow_and_fail_run(
|
||||
mlir_input, args, expected_result, keyset_cache, no_parallel
|
||||
):
|
||||
if no_parallel:
|
||||
engine = JITSupport.new()
|
||||
options = CompilationOptions.new("main")
|
||||
options.set_auto_parallelize(True)
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="call: current runtime doesn't support dataflow execution",
|
||||
):
|
||||
compile_run_assert(
|
||||
engine, mlir_input, args, expected_result, keyset_cache, options=options
|
||||
)
|
||||
# This test was running in JIT mode at first. Problem is now, it does not work with the library
|
||||
# support. It is not clear to me why, but the dataflow runtime seems to have stuffs dedicated to
|
||||
# the dropped JIT support... I am cancelling it until further explored.
|
||||
#
|
||||
# # FIXME #51
|
||||
# @pytest.mark.xfail(
|
||||
# platform.system() == "Darwin",
|
||||
# reason="MacOS have issues with translating Cpp exceptions",
|
||||
# )
|
||||
# @pytest.mark.parametrize(
|
||||
# "mlir_input, args, expected_result", end_to_end_parallel_fixture
|
||||
# )
|
||||
# def test_compile_dataflow_and_fail_run(
|
||||
# mlir_input, args, expected_result, keyset_cache, no_parallel
|
||||
# ):
|
||||
# if no_parallel:
|
||||
# artifact_dir = "./py_test_compile_dataflow_and_fail_run"
|
||||
# engine = LibrarySupport.new(artifact_dir)
|
||||
# options = CompilationOptions.new("main")
|
||||
# options.set_auto_parallelize(True)
|
||||
# with pytest.raises(
|
||||
# RuntimeError,
|
||||
# match="call: current runtime doesn't support dataflow execution",
|
||||
# ):
|
||||
# compile_run_assert(
|
||||
# engine, mlir_input, args, expected_result, keyset_cache, options=options
|
||||
# )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -340,17 +329,11 @@ def test_compile_dataflow_and_fail_run(
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"EngineClass",
|
||||
[
|
||||
pytest.param(JITSupport, id="JIT"),
|
||||
pytest.param(LibrarySupport, id="Library"),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_loop_parallelize(
|
||||
mlir_input, args, expected_result, keyset_cache, EngineClass
|
||||
mlir_input, args, expected_result, keyset_cache
|
||||
):
|
||||
engine = EngineClass.new()
|
||||
artifact_dir = "./py_test_compile_and_run_loop_parallelize"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
options = CompilationOptions.new("main")
|
||||
options.set_loop_parallelize(True)
|
||||
compile_run_assert(
|
||||
@@ -378,17 +361,9 @@ def test_compile_and_run_loop_parallelize(
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"EngineClass",
|
||||
[
|
||||
pytest.param(JITSupport, id="JIT"),
|
||||
pytest.param(LibrarySupport, id="Library"),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_invalid_arg_number(
|
||||
mlir_input, args, EngineClass, keyset_cache
|
||||
):
|
||||
engine = EngineClass.new()
|
||||
def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache):
|
||||
artifact_dir = "./py_test_compile_and_run_invalid_arg_number"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
with pytest.raises(
|
||||
RuntimeError, match=r"function has arity 2 but is applied to too many arguments"
|
||||
):
|
||||
@@ -417,7 +392,8 @@ def test_compile_and_run_invalid_arg_number(
|
||||
],
|
||||
)
|
||||
def test_compile_invalid(mlir_input):
|
||||
engine = JITSupport.new()
|
||||
artifact_dir = "./py_test_compile_invalid"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
with pytest.raises(RuntimeError, match=r"Function not found, name='main'"):
|
||||
engine.compile(mlir_input)
|
||||
|
||||
@@ -433,7 +409,8 @@ func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
|
||||
"""
|
||||
|
||||
engine = JITSupport.new()
|
||||
artifact_dir = "./py_test_crt_decomposition_feedback"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
compilation_result = engine.compile(mlir, options=CompilationOptions.new("main"))
|
||||
compilation_feedback = engine.load_compilation_feedback(compilation_result)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ module {
|
||||
support = LibrarySupport.new(str(tmpdirname))
|
||||
compilation_result = support.compile(mlir)
|
||||
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
server_lambda = support.load_server_lambda(compilation_result, False)
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
|
||||
keyset = ClientSupport.key_set(client_parameters)
|
||||
|
||||
@@ -37,7 +37,7 @@ def run_simulated(engine, args_and_shape, compilation_result):
|
||||
values.append(sim_value_exporter.export_tensor(pos, arg, shape))
|
||||
pos += 1
|
||||
public_arguments = PublicArguments.new(client_parameters, values)
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
server_lambda = engine.load_server_lambda(compilation_result, True)
|
||||
public_result = engine.simulate(server_lambda, public_arguments)
|
||||
sim_value_decrypter = SimulatedValueDecrypter.new(client_parameters)
|
||||
result = sim_value_decrypter.decrypt(0, public_result.get_value(0))
|
||||
|
||||
@@ -3,9 +3,6 @@ from concrete.compiler import (
|
||||
ClientParameters,
|
||||
ClientSupport,
|
||||
CompilationOptions,
|
||||
JITCompilationResult,
|
||||
JITLambda,
|
||||
JITSupport,
|
||||
KeySetCache,
|
||||
KeySet,
|
||||
LambdaArgument,
|
||||
@@ -24,9 +21,6 @@ from concrete.compiler import (
|
||||
pytest.param(ClientParameters, id="ClientParameters"),
|
||||
pytest.param(ClientSupport, id="ClientSupport"),
|
||||
pytest.param(CompilationOptions, id="CompilationOptions"),
|
||||
pytest.param(JITCompilationResult, id="JITCompilationResult"),
|
||||
pytest.param(JITLambda, id="JITLambda"),
|
||||
pytest.param(JITSupport, id="JITSupport"),
|
||||
pytest.param(KeySetCache, id="KeySetCache"),
|
||||
pytest.param(KeySet, id="KeySet"),
|
||||
pytest.param(LambdaArgument, id="LambdaArgument"),
|
||||
|
||||
Reference in New Issue
Block a user