diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Bindings/Python/CompilerEngine.h b/compilers/concrete-compiler/compiler/include/concretelang/Bindings/Python/CompilerEngine.h index a099873ae..f22becbed 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Bindings/Python/CompilerEngine.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Bindings/Python/CompilerEngine.h @@ -92,7 +92,7 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, concretelang::clientlib::KeySet &keySet, llvm::ArrayRef args); -MLIR_CAPI_EXPORTED lambdaArgument +MLIR_CAPI_EXPORTED std::vector decrypt_result(concretelang::clientlib::ClientParameters clientParameters, concretelang::clientlib::KeySet &keySet, concretelang::clientlib::PublicResult &publicResult); diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerEngine.cpp index e9413c1d0..d40c3fd93 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerEngine.cpp @@ -165,7 +165,7 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, std::move(publicArgs)); } -MLIR_CAPI_EXPORTED lambdaArgument +MLIR_CAPI_EXPORTED std::vector decrypt_result(concretelang::clientlib::ClientParameters clientParameters, concretelang::clientlib::KeySet &keySet, concretelang::clientlib::PublicResult &publicResult) { @@ -176,23 +176,24 @@ decrypt_result(concretelang::clientlib::ClientParameters clientParameters, if (maybeProgram.has_failure()) { throw std::runtime_error(maybeProgram.as_failure().error().mesg); } - if (publicResult.values.size() != 1) { - throw std::runtime_error("Tried to decrypt with wrong arity."); - } auto circuit = maybeProgram.value() .getClientCircuit(clientParameters.programInfo.asReader() .getCircuits()[0] .getName()) .value(); - auto maybeProcessed = circuit.processOutput(publicResult.values[0], 0); - if (maybeProcessed.has_failure()) { - throw std::runtime_error(maybeProcessed.as_failure().error().mesg); - } + std::vector results; + for (auto e : llvm::enumerate(publicResult.values)) { + auto maybeProcessed = circuit.processOutput(e.value(), e.index()); + if (maybeProcessed.has_failure()) { + throw std::runtime_error(maybeProcessed.as_failure().error().mesg); + } - mlir::concretelang::LambdaArgument out{maybeProcessed.value()}; - lambdaArgument tensor_arg{ - std::make_shared(std::move(out))}; - return tensor_arg; + mlir::concretelang::LambdaArgument out{maybeProcessed.value()}; + lambdaArgument tensor_arg{ + std::make_shared(std::move(out))}; + results.push_back(tensor_arg); + } + return results; } MLIR_CAPI_EXPORTED std::unique_ptr diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py index b0184d0f3..bf90a00f8 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py @@ -175,30 +175,34 @@ class ClientSupport(WrapperCpp): raise TypeError( f"public_result must be of type PublicResult, not {type(public_result)}" ) - lambda_arg = LambdaArgument.wrap( - _ClientSupport.decrypt_result( - client_parameters.cpp(), keyset.cpp(), public_result.cpp() - ) + results = _ClientSupport.decrypt_result( + client_parameters.cpp(), keyset.cpp(), public_result.cpp() ) - output_signs = client_parameters.output_signs() - assert len(output_signs) == 1 + def process_result(result): + lambda_arg = LambdaArgument.wrap(result) + is_signed = lambda_arg.is_signed() + if lambda_arg.is_scalar(): + return ( + lambda_arg.get_signed_scalar() + if is_signed + else lambda_arg.get_scalar() + ) - is_signed = lambda_arg.is_signed() - if lambda_arg.is_scalar(): - return ( - lambda_arg.get_signed_scalar() if is_signed else lambda_arg.get_scalar() - ) + if lambda_arg.is_tensor(): + return np.array( + lambda_arg.get_signed_tensor_data() + if is_signed + else lambda_arg.get_tensor_data(), + dtype=(np.int64 if is_signed else np.uint64), + ).reshape(lambda_arg.get_tensor_shape()) - if lambda_arg.is_tensor(): - return np.array( - lambda_arg.get_signed_tensor_data() - if is_signed - else lambda_arg.get_tensor_data(), - dtype=(np.int64 if is_signed else np.uint64), - ).reshape(lambda_arg.get_tensor_shape()) + raise RuntimeError("unknown return type") - raise RuntimeError("unknown return type") + processed_results = tuple(map(process_result, results)) + if len(processed_results) == 1: + return processed_results[0] + return processed_results @staticmethod def _create_lambda_argument( diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_multi_ouput.yaml b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_multi_ouput.yaml new file mode 100644 index 000000000..19e211012 --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_multi_ouput.yaml @@ -0,0 +1,59 @@ +description: zero +program: | + func.func @main() -> (tensor<2x!FHE.eint<7>>, !FHE.eint<2>) { + %0 = "FHE.zero_tensor"(): () -> tensor<2x!FHE.eint<7>> + %1 = "FHE.zero"(): () -> !FHE.eint<2> + return %0, %1 : tensor<2x!FHE.eint<7>>, !FHE.eint<2> + } +tests: + - inputs: + - outputs: + - tensor: [0, 0] + shape: [2] + - scalar: 0 +--- +description: identity_mono_precision +program: | + func.func @main(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>, %arg2: !FHE.eint<3>, %arg3: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>) { + return %arg0, %arg1, %arg2, %arg3 : !FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3> + } +tests: + - inputs: + - scalar: 7 + - scalar: 2 + - scalar: 3 + - scalar: 6 + outputs: + - scalar: 7 + - scalar: 2 + - scalar: 3 + - scalar: 6 +--- +description: apply_lookup_table_multi_precision +program: | + func.func @main(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<4>, %arg2: !FHE.eint<5>, %arg3: !FHE.eint<6>) -> (!FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6>) { + %lut0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64> + %bs0 = "FHE.apply_lookup_table"(%arg0, %lut0): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>) + + %lut1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> + %bs1 = "FHE.apply_lookup_table"(%arg1, %lut1): (!FHE.eint<4>, tensor<16xi64>) -> (!FHE.eint<4>) + + %lut3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : tensor<32xi64> + %bs3 = "FHE.apply_lookup_table"(%arg2, %lut3): (!FHE.eint<5>, tensor<32xi64>) -> (!FHE.eint<5>) + + %lut4 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64> + %bs4 = "FHE.apply_lookup_table"(%arg3, %lut4): (!FHE.eint<6>, tensor<64xi64>) -> (!FHE.eint<6>) + + return %bs0, %bs1, %bs3, %bs4 : !FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6> + } +tests: + - inputs: + - scalar: 7 + - scalar: 15 + - scalar: 31 + - scalar: 63 + outputs: + - scalar: 7 + - scalar: 15 + - scalar: 31 + - scalar: 63 diff --git a/compilers/concrete-compiler/compiler/tests/python/test_client_server.py b/compilers/concrete-compiler/compiler/tests/python/test_client_server.py index 1eace918e..84072b1ac 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_client_server.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_client_server.py @@ -75,6 +75,28 @@ func.func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> ten np.array([8, 2, 4, 9]), id="enc_enc_ndarray_args", ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<4>, %arg2: !FHE.eint<5>, %arg3: !FHE.eint<6>) -> (!FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6>) { + %lut0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64> + %bs0 = "FHE.apply_lookup_table"(%arg0, %lut0): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>) + + %lut1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> + %bs1 = "FHE.apply_lookup_table"(%arg1, %lut1): (!FHE.eint<4>, tensor<16xi64>) -> (!FHE.eint<4>) + + %lut3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : tensor<32xi64> + %bs3 = "FHE.apply_lookup_table"(%arg2, %lut3): (!FHE.eint<5>, tensor<32xi64>) -> (!FHE.eint<5>) + + %lut4 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64> + %bs4 = "FHE.apply_lookup_table"(%arg3, %lut4): (!FHE.eint<6>, tensor<64xi64>) -> (!FHE.eint<6>) + + return %bs0, %bs1, %bs3, %bs4 : !FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6> + } + """, + (7, 15, 31, 63), + (7, 15, 31, 63), + id="apply_lookup_table_multi_ouput", + ), ], ) def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache):