From 09dde3013e4048b087fd7af246ade3a544b215b9 Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 12 Aug 2022 17:01:03 +0200 Subject: [PATCH] feat: support passing plain arrays to encrypt --- concrete/numpy/compilation/client.py | 3 +++ tests/compilation/test_circuit.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/concrete/numpy/compilation/client.py b/concrete/numpy/compilation/client.py index 1dc22e87b..a48a1451c 100644 --- a/concrete/numpy/compilation/client.py +++ b/concrete/numpy/compilation/client.py @@ -128,6 +128,9 @@ class Client: sanitized_args: Dict[int, Union[int, np.ndarray]] = {} for index, spec in enumerate(input_specs): arg = args[index] + if isinstance(arg, list): + arg = np.array(arg) + is_valid = isinstance(arg, (int, np.integer)) or ( isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer) ) diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 0f5f48fe4..814033bf2 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -178,7 +178,7 @@ def test_client_server_api(helpers): def function(x): return x + 42 - inputset = range(10) + inputset = [np.random.randint(0, 10, size=(3,)) for _ in range(10)] circuit = function.compile(inputset, configuration.fork(jit=False)) # for coverage @@ -206,7 +206,7 @@ def test_client_server_api(helpers): ] for client in clients: - args = client.encrypt(4) + args = client.encrypt([3, 8, 1]) serialized_args = client.specs.serialize_public_args(args) serialized_evaluation_keys = client.evaluation_keys.serialize() @@ -220,7 +220,7 @@ def test_client_server_api(helpers): unserialized_result = client.specs.unserialize_public_result(serialized_result) output = client.decrypt(unserialized_result) - assert output == 46 + assert np.array_equal(output, [45, 50, 43]) server.cleanup()