From 0a5881096ca3f588dfca28e61eb5061a6798b720 Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 15 Apr 2022 09:27:34 +0100 Subject: [PATCH] feat(python): serialize public arguments --- .../concretelang-c/Support/CompilerEngine.h | 10 ++ .../lib/Bindings/Python/CompilerAPIModule.cpp | 12 +- .../concrete/compiler/public_arguments.py | 36 +++++ compiler/lib/CAPI/Support/CompilerEngine.cpp | 24 +++ compiler/lib/ClientLib/PublicArguments.cpp | 2 +- compiler/tests/python/test_compilation.py | 1 - compiler/tests/python/test_serialization.py | 143 ++++++++++++++++++ 7 files changed, 225 insertions(+), 3 deletions(-) create mode 100644 compiler/tests/python/test_serialization.py diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 90512590c..f433d0d05 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -97,6 +97,16 @@ MLIR_CAPI_EXPORTED lambdaArgument decrypt_result(concretelang::clientlib::KeySet &keySet, concretelang::clientlib::PublicResult &publicResult); +// Serialization //////////////////////////////////////////////////////////// + +MLIR_CAPI_EXPORTED std::unique_ptr +publicArgumentsUnserialize( + mlir::concretelang::ClientParameters &clientParameters, + const std::string &buffer); + +MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize( + concretelang::clientlib::PublicArguments &publicArguments); + // Parse then print a textual representation of an MLIR module MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 842dbf145..49bad68ab 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -155,7 +155,17 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( pybind11::class_(m, "ClientParameters"); pybind11::class_(m, "KeySet"); - pybind11::class_(m, "PublicArguments"); + pybind11::class_>( + m, "PublicArguments") + .def_static("unserialize", + [](mlir::concretelang::ClientParameters &clientParameters, + const pybind11::bytes &buffer) { + return publicArgumentsUnserialize(clientParameters, buffer); + }) + .def("serialize", [](clientlib::PublicArguments &publicArgument) { + return pybind11::bytes(publicArgumentsSerialize(publicArgument)); + }); pybind11::class_(m, "PublicResult"); pybind11::class_(m, "LambdaArgument") diff --git a/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py b/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py index 8c4103286..9c7db0e5a 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py @@ -9,6 +9,7 @@ from mlir._mlir_libs._concretelang._compiler import ( ) # pylint: enable=no-name-in-module,import-error +from .client_parameters import ClientParameters from .wrapper import WrapperCpp @@ -33,3 +34,38 @@ class PublicArguments(WrapperCpp): f"public_arguments must be of type _PublicArguments, not {type(public_arguments)}" ) super().__init__(public_arguments) + + def serialize(self) -> bytes: + """Serialize the PublicArguments into a buffer. + + Returns: + bytes: serialized object + """ + return self.cpp().serialize() + + @staticmethod + def unserialize( + client_parameters: ClientParameters, buffer: bytes + ) -> "PublicArguments": + """Unserialize PublicArguments from a buffer. + + Args: + client_parameters (ClientParameters): client parameters of the compiled circuit + buffer (bytes): previously serialized PublicArguments + + Raises: + TypeError: if client_parameters is not of type ClientParameters + TypeError: if buffer is not of type bytes + + Returns: + PublicArguments: unserialized object + """ + if not isinstance(client_parameters, ClientParameters): + raise TypeError( + f"client_parameters must be of type ClientParameters, not {type(client_parameters)}" + ) + if not isinstance(buffer, bytes): + raise TypeError(f"buffer must be of type bytes, not {type(buffer)}") + return PublicArguments.wrap( + _PublicArguments.unserialize(client_parameters.cpp(), buffer) + ) diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index a441f0131..31e33ec26 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -133,6 +133,30 @@ decrypt_result(concretelang::clientlib::KeySet &keySet, return std::move(result_); } +MLIR_CAPI_EXPORTED std::unique_ptr +publicArgumentsUnserialize( + mlir::concretelang::ClientParameters &clientParameters, + const std::string &buffer) { + std::stringstream istream(buffer); + auto argsOrError = concretelang::clientlib::PublicArguments::unserialize( + clientParameters, istream); + if (!argsOrError) { + throw std::runtime_error(argsOrError.error().mesg); + } + return std::move(argsOrError.value()); +} + +MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize( + concretelang::clientlib::PublicArguments &publicArguments) { + + std::ostringstream buffer(std::ios::binary); + auto voidOrError = publicArguments.serialize(buffer); + if (!voidOrError) { + throw std::runtime_error(voidOrError.error().mesg); + } + return buffer.str(); +} + void terminateParallelization() { #ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED _dfr_terminate(); diff --git a/compiler/lib/ClientLib/PublicArguments.cpp b/compiler/lib/ClientLib/PublicArguments.cpp index 553152ac5..768fa318b 100644 --- a/compiler/lib/ClientLib/PublicArguments.cpp +++ b/compiler/lib/ClientLib/PublicArguments.cpp @@ -133,7 +133,7 @@ PublicArguments::unserialize(ClientParameters &clientParameters, clientParameters, runtimeContext, true, std::move(empty), std::move(emptyBuffers)); OUTCOME_TRYV(sArguments->unserializeArgs(istream)); - return sArguments; + return std::move(sArguments); } void next_coord_index(size_t index[], size_t sizes[], size_t rank) { diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py index 35588cc88..92dcfc0cf 100644 --- a/compiler/tests/python/test_compilation.py +++ b/compiler/tests/python/test_compilation.py @@ -1,4 +1,3 @@ -from typing import Union import pytest import numpy as np from concrete.compiler import ( diff --git a/compiler/tests/python/test_serialization.py b/compiler/tests/python/test_serialization.py new file mode 100644 index 000000000..ad66f32ee --- /dev/null +++ b/compiler/tests/python/test_serialization.py @@ -0,0 +1,143 @@ +import pytest +import numpy as np +from concrete.compiler import ( + JITSupport, + LibrarySupport, + ClientSupport, + CompilationOptions, + PublicArguments, +) + + +def assert_result(result, expected_result): + """Assert that result and expected result are equal. + + result and expected_result can be integers on numpy arrays. + """ + assert type(expected_result) == type(result) + if isinstance(expected_result, int): + assert result == expected_result + else: + assert np.all(result == expected_result) + + +# TODO(#541): add result serialization +def run_with_serialization( + engine, + args, + compilation_result, + keyset_cache, +): + """Execute engine on the given arguments. Performs serialization betwee client/server. + + Perform required loading, encryption, execution, and decryption.""" + # Client + client_parameters = engine.load_client_parameters(compilation_result) + key_set = ClientSupport.key_set(client_parameters, keyset_cache) + public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) + public_arguments_buffer = public_arguments.serialize() + # Server + public_arguments = PublicArguments.unserialize( + client_parameters, public_arguments_buffer + ) + del public_arguments_buffer + server_lambda = engine.load_server_lambda(compilation_result) + public_result = engine.server_call(server_lambda, public_arguments) + # Client + result = ClientSupport.decrypt_result(key_set, public_result) + return result + + +def compile_run_assert_with_serialization( + engine, + mlir_input, + args, + expected_result, + keyset_cache, +): + """Compile run and assert result. Performs serialization betwee client/server. + + Can take both JITSupport or LibrarySupport as engine. + """ + options = CompilationOptions.new("main") + compilation_result = engine.compile(mlir_input, options) + result = run_with_serialization(engine, args, compilation_result, keyset_cache) + assert_result(result, expected_result) + + +end_to_end_fixture = [ + pytest.param( + """ + func @main(%arg0: !FHE.eint<5>, %arg1: i6) -> !FHE.eint<5> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<5>, i6) -> (!FHE.eint<5>) + return %1: !FHE.eint<5> + } + """, + (5, 7), + 12, + id="enc_plain_int_args", + marks=pytest.mark.xfail, + ), + pytest.param( + """ + func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.eint<5>) + return %1: !FHE.eint<5> + } + """, + (5, 7), + 12, + id="enc_enc_int_args", + ), + pytest.param( + """ + func @main(%arg0: tensor<4x!FHE.eint<5>>, %arg1: tensor<4xi6>) -> !FHE.eint<5> + { + %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : + (tensor<4x!FHE.eint<5>>, tensor<4xi6>) -> !FHE.eint<5> + return %ret : !FHE.eint<5> + } + """, + ( + np.array([1, 2, 3, 4], dtype=np.uint8), + np.array([4, 3, 2, 1], dtype=np.uint8), + ), + 20, + id="enc_plain_ndarray_args", + marks=pytest.mark.xfail, + ), + pytest.param( + """ + func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> { + %res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<5>>, tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> + return %res : tensor<4x!FHE.eint<5>> + } + """, + ( + np.array([1, 2, 3, 4], dtype=np.uint8), + np.array([7, 0, 1, 5], dtype=np.uint8), + ), + np.array([8, 2, 4, 9]), + id="enc_enc_ndarray_args", + ), +] + + +@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) +def test_jit_compile_and_run_with_serialization( + mlir_input, args, expected_result, keyset_cache +): + engine = JITSupport.new() + compile_run_assert_with_serialization( + engine, mlir_input, args, expected_result, keyset_cache + ) + + +@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) +def test_lib_compile_and_run_with_serialization( + mlir_input, args, expected_result, keyset_cache +): + engine = LibrarySupport.new("./py_test_lib_compile_and_run") + compile_run_assert_with_serialization( + engine, mlir_input, args, expected_result, keyset_cache + )