feat(compiler): Add support for multi output function up to python bindings

This commit is contained in:
Bourgerie Quentin
2023-11-09 19:01:33 +01:00
committed by Quentin Bourgerie
parent d4cc79f10d
commit 09af803754
5 changed files with 118 additions and 32 deletions

View File

@@ -92,7 +92,7 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
concretelang::clientlib::KeySet &keySet,
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args);
MLIR_CAPI_EXPORTED lambdaArgument
MLIR_CAPI_EXPORTED std::vector<lambdaArgument>
decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::PublicResult &publicResult);

View File

@@ -165,7 +165,7 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
std::move(publicArgs));
}
MLIR_CAPI_EXPORTED lambdaArgument
MLIR_CAPI_EXPORTED std::vector<lambdaArgument>
decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::PublicResult &publicResult) {
@@ -176,23 +176,24 @@ decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
if (publicResult.values.size() != 1) {
throw std::runtime_error("Tried to decrypt with wrong arity.");
}
auto circuit = maybeProgram.value()
.getClientCircuit(clientParameters.programInfo.asReader()
.getCircuits()[0]
.getName())
.value();
auto maybeProcessed = circuit.processOutput(publicResult.values[0], 0);
if (maybeProcessed.has_failure()) {
throw std::runtime_error(maybeProcessed.as_failure().error().mesg);
}
std::vector<lambdaArgument> results;
for (auto e : llvm::enumerate(publicResult.values)) {
auto maybeProcessed = circuit.processOutput(e.value(), e.index());
if (maybeProcessed.has_failure()) {
throw std::runtime_error(maybeProcessed.as_failure().error().mesg);
}
mlir::concretelang::LambdaArgument out{maybeProcessed.value()};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return tensor_arg;
mlir::concretelang::LambdaArgument out{maybeProcessed.value()};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
results.push_back(tensor_arg);
}
return results;
}
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>

View File

@@ -175,30 +175,34 @@ class ClientSupport(WrapperCpp):
raise TypeError(
f"public_result must be of type PublicResult, not {type(public_result)}"
)
lambda_arg = LambdaArgument.wrap(
_ClientSupport.decrypt_result(
client_parameters.cpp(), keyset.cpp(), public_result.cpp()
)
results = _ClientSupport.decrypt_result(
client_parameters.cpp(), keyset.cpp(), public_result.cpp()
)
output_signs = client_parameters.output_signs()
assert len(output_signs) == 1
def process_result(result):
lambda_arg = LambdaArgument.wrap(result)
is_signed = lambda_arg.is_signed()
if lambda_arg.is_scalar():
return (
lambda_arg.get_signed_scalar()
if is_signed
else lambda_arg.get_scalar()
)
is_signed = lambda_arg.is_signed()
if lambda_arg.is_scalar():
return (
lambda_arg.get_signed_scalar() if is_signed else lambda_arg.get_scalar()
)
if lambda_arg.is_tensor():
return np.array(
lambda_arg.get_signed_tensor_data()
if is_signed
else lambda_arg.get_tensor_data(),
dtype=(np.int64 if is_signed else np.uint64),
).reshape(lambda_arg.get_tensor_shape())
if lambda_arg.is_tensor():
return np.array(
lambda_arg.get_signed_tensor_data()
if is_signed
else lambda_arg.get_tensor_data(),
dtype=(np.int64 if is_signed else np.uint64),
).reshape(lambda_arg.get_tensor_shape())
raise RuntimeError("unknown return type")
raise RuntimeError("unknown return type")
processed_results = tuple(map(process_result, results))
if len(processed_results) == 1:
return processed_results[0]
return processed_results
@staticmethod
def _create_lambda_argument(

View File

@@ -0,0 +1,59 @@
description: zero
program: |
func.func @main() -> (tensor<2x!FHE.eint<7>>, !FHE.eint<2>) {
%0 = "FHE.zero_tensor"(): () -> tensor<2x!FHE.eint<7>>
%1 = "FHE.zero"(): () -> !FHE.eint<2>
return %0, %1 : tensor<2x!FHE.eint<7>>, !FHE.eint<2>
}
tests:
- inputs:
- outputs:
- tensor: [0, 0]
shape: [2]
- scalar: 0
---
description: identity_mono_precision
program: |
func.func @main(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>, %arg2: !FHE.eint<3>, %arg3: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>) {
return %arg0, %arg1, %arg2, %arg3 : !FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>
}
tests:
- inputs:
- scalar: 7
- scalar: 2
- scalar: 3
- scalar: 6
outputs:
- scalar: 7
- scalar: 2
- scalar: 3
- scalar: 6
---
description: apply_lookup_table_multi_precision
program: |
func.func @main(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<4>, %arg2: !FHE.eint<5>, %arg3: !FHE.eint<6>) -> (!FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6>) {
%lut0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
%bs0 = "FHE.apply_lookup_table"(%arg0, %lut0): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>)
%lut1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
%bs1 = "FHE.apply_lookup_table"(%arg1, %lut1): (!FHE.eint<4>, tensor<16xi64>) -> (!FHE.eint<4>)
%lut3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : tensor<32xi64>
%bs3 = "FHE.apply_lookup_table"(%arg2, %lut3): (!FHE.eint<5>, tensor<32xi64>) -> (!FHE.eint<5>)
%lut4 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>
%bs4 = "FHE.apply_lookup_table"(%arg3, %lut4): (!FHE.eint<6>, tensor<64xi64>) -> (!FHE.eint<6>)
return %bs0, %bs1, %bs3, %bs4 : !FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6>
}
tests:
- inputs:
- scalar: 7
- scalar: 15
- scalar: 31
- scalar: 63
outputs:
- scalar: 7
- scalar: 15
- scalar: 31
- scalar: 63

View File

@@ -75,6 +75,28 @@ func.func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> ten
np.array([8, 2, 4, 9]),
id="enc_enc_ndarray_args",
),
pytest.param(
"""
func.func @main(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<4>, %arg2: !FHE.eint<5>, %arg3: !FHE.eint<6>) -> (!FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6>) {
%lut0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
%bs0 = "FHE.apply_lookup_table"(%arg0, %lut0): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>)
%lut1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
%bs1 = "FHE.apply_lookup_table"(%arg1, %lut1): (!FHE.eint<4>, tensor<16xi64>) -> (!FHE.eint<4>)
%lut3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : tensor<32xi64>
%bs3 = "FHE.apply_lookup_table"(%arg2, %lut3): (!FHE.eint<5>, tensor<32xi64>) -> (!FHE.eint<5>)
%lut4 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>
%bs4 = "FHE.apply_lookup_table"(%arg3, %lut4): (!FHE.eint<6>, tensor<64xi64>) -> (!FHE.eint<6>)
return %bs0, %bs1, %bs3, %bs4 : !FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6>
}
""",
(7, 15, 31, 63),
(7, 15, 31, 63),
id="apply_lookup_table_multi_ouput",
),
],
)
def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache):