mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 08:01:20 -05:00
feat(compiler): Add support for multi output function up to python bindings
This commit is contained in:
committed by
Quentin Bourgerie
parent
d4cc79f10d
commit
09af803754
@@ -92,7 +92,7 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args);
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument
|
||||
MLIR_CAPI_EXPORTED std::vector<lambdaArgument>
|
||||
decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult);
|
||||
|
||||
@@ -165,7 +165,7 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
std::move(publicArgs));
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument
|
||||
MLIR_CAPI_EXPORTED std::vector<lambdaArgument>
|
||||
decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult) {
|
||||
@@ -176,23 +176,24 @@ decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
if (publicResult.values.size() != 1) {
|
||||
throw std::runtime_error("Tried to decrypt with wrong arity.");
|
||||
}
|
||||
auto circuit = maybeProgram.value()
|
||||
.getClientCircuit(clientParameters.programInfo.asReader()
|
||||
.getCircuits()[0]
|
||||
.getName())
|
||||
.value();
|
||||
auto maybeProcessed = circuit.processOutput(publicResult.values[0], 0);
|
||||
if (maybeProcessed.has_failure()) {
|
||||
throw std::runtime_error(maybeProcessed.as_failure().error().mesg);
|
||||
}
|
||||
std::vector<lambdaArgument> results;
|
||||
for (auto e : llvm::enumerate(publicResult.values)) {
|
||||
auto maybeProcessed = circuit.processOutput(e.value(), e.index());
|
||||
if (maybeProcessed.has_failure()) {
|
||||
throw std::runtime_error(maybeProcessed.as_failure().error().mesg);
|
||||
}
|
||||
|
||||
mlir::concretelang::LambdaArgument out{maybeProcessed.value()};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return tensor_arg;
|
||||
mlir::concretelang::LambdaArgument out{maybeProcessed.value()};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
results.push_back(tensor_arg);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
|
||||
@@ -175,30 +175,34 @@ class ClientSupport(WrapperCpp):
|
||||
raise TypeError(
|
||||
f"public_result must be of type PublicResult, not {type(public_result)}"
|
||||
)
|
||||
lambda_arg = LambdaArgument.wrap(
|
||||
_ClientSupport.decrypt_result(
|
||||
client_parameters.cpp(), keyset.cpp(), public_result.cpp()
|
||||
)
|
||||
results = _ClientSupport.decrypt_result(
|
||||
client_parameters.cpp(), keyset.cpp(), public_result.cpp()
|
||||
)
|
||||
|
||||
output_signs = client_parameters.output_signs()
|
||||
assert len(output_signs) == 1
|
||||
def process_result(result):
|
||||
lambda_arg = LambdaArgument.wrap(result)
|
||||
is_signed = lambda_arg.is_signed()
|
||||
if lambda_arg.is_scalar():
|
||||
return (
|
||||
lambda_arg.get_signed_scalar()
|
||||
if is_signed
|
||||
else lambda_arg.get_scalar()
|
||||
)
|
||||
|
||||
is_signed = lambda_arg.is_signed()
|
||||
if lambda_arg.is_scalar():
|
||||
return (
|
||||
lambda_arg.get_signed_scalar() if is_signed else lambda_arg.get_scalar()
|
||||
)
|
||||
if lambda_arg.is_tensor():
|
||||
return np.array(
|
||||
lambda_arg.get_signed_tensor_data()
|
||||
if is_signed
|
||||
else lambda_arg.get_tensor_data(),
|
||||
dtype=(np.int64 if is_signed else np.uint64),
|
||||
).reshape(lambda_arg.get_tensor_shape())
|
||||
|
||||
if lambda_arg.is_tensor():
|
||||
return np.array(
|
||||
lambda_arg.get_signed_tensor_data()
|
||||
if is_signed
|
||||
else lambda_arg.get_tensor_data(),
|
||||
dtype=(np.int64 if is_signed else np.uint64),
|
||||
).reshape(lambda_arg.get_tensor_shape())
|
||||
raise RuntimeError("unknown return type")
|
||||
|
||||
raise RuntimeError("unknown return type")
|
||||
processed_results = tuple(map(process_result, results))
|
||||
if len(processed_results) == 1:
|
||||
return processed_results[0]
|
||||
return processed_results
|
||||
|
||||
@staticmethod
|
||||
def _create_lambda_argument(
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
description: zero
|
||||
program: |
|
||||
func.func @main() -> (tensor<2x!FHE.eint<7>>, !FHE.eint<2>) {
|
||||
%0 = "FHE.zero_tensor"(): () -> tensor<2x!FHE.eint<7>>
|
||||
%1 = "FHE.zero"(): () -> !FHE.eint<2>
|
||||
return %0, %1 : tensor<2x!FHE.eint<7>>, !FHE.eint<2>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- outputs:
|
||||
- tensor: [0, 0]
|
||||
shape: [2]
|
||||
- scalar: 0
|
||||
---
|
||||
description: identity_mono_precision
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<3>, %arg2: !FHE.eint<3>, %arg3: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>) {
|
||||
return %arg0, %arg1, %arg2, %arg3 : !FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>, !FHE.eint<3>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 2
|
||||
- scalar: 3
|
||||
- scalar: 6
|
||||
outputs:
|
||||
- scalar: 7
|
||||
- scalar: 2
|
||||
- scalar: 3
|
||||
- scalar: 6
|
||||
---
|
||||
description: apply_lookup_table_multi_precision
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<4>, %arg2: !FHE.eint<5>, %arg3: !FHE.eint<6>) -> (!FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6>) {
|
||||
%lut0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
|
||||
%bs0 = "FHE.apply_lookup_table"(%arg0, %lut0): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>)
|
||||
|
||||
%lut1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
%bs1 = "FHE.apply_lookup_table"(%arg1, %lut1): (!FHE.eint<4>, tensor<16xi64>) -> (!FHE.eint<4>)
|
||||
|
||||
%lut3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : tensor<32xi64>
|
||||
%bs3 = "FHE.apply_lookup_table"(%arg2, %lut3): (!FHE.eint<5>, tensor<32xi64>) -> (!FHE.eint<5>)
|
||||
|
||||
%lut4 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>
|
||||
%bs4 = "FHE.apply_lookup_table"(%arg3, %lut4): (!FHE.eint<6>, tensor<64xi64>) -> (!FHE.eint<6>)
|
||||
|
||||
return %bs0, %bs1, %bs3, %bs4 : !FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 15
|
||||
- scalar: 31
|
||||
- scalar: 63
|
||||
outputs:
|
||||
- scalar: 7
|
||||
- scalar: 15
|
||||
- scalar: 31
|
||||
- scalar: 63
|
||||
@@ -75,6 +75,28 @@ func.func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> ten
|
||||
np.array([8, 2, 4, 9]),
|
||||
id="enc_enc_ndarray_args",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func.func @main(%arg0: !FHE.eint<3>, %arg1: !FHE.eint<4>, %arg2: !FHE.eint<5>, %arg3: !FHE.eint<6>) -> (!FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6>) {
|
||||
%lut0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
|
||||
%bs0 = "FHE.apply_lookup_table"(%arg0, %lut0): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>)
|
||||
|
||||
%lut1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
%bs1 = "FHE.apply_lookup_table"(%arg1, %lut1): (!FHE.eint<4>, tensor<16xi64>) -> (!FHE.eint<4>)
|
||||
|
||||
%lut3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : tensor<32xi64>
|
||||
%bs3 = "FHE.apply_lookup_table"(%arg2, %lut3): (!FHE.eint<5>, tensor<32xi64>) -> (!FHE.eint<5>)
|
||||
|
||||
%lut4 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>
|
||||
%bs4 = "FHE.apply_lookup_table"(%arg3, %lut4): (!FHE.eint<6>, tensor<64xi64>) -> (!FHE.eint<6>)
|
||||
|
||||
return %bs0, %bs1, %bs3, %bs4 : !FHE.eint<3>, !FHE.eint<4>, !FHE.eint<5>, !FHE.eint<6>
|
||||
}
|
||||
""",
|
||||
(7, 15, 31, 63),
|
||||
(7, 15, 31, 63),
|
||||
id="apply_lookup_table_multi_ouput",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache):
|
||||
|
||||
Reference in New Issue
Block a user