feat: support passing plain arrays to encrypt

This commit is contained in:
Umut
2022-08-12 17:01:03 +02:00
parent 6c6e657b6e
commit 09dde3013e
2 changed files with 6 additions and 3 deletions

View File

@@ -128,6 +128,9 @@ class Client:
sanitized_args: Dict[int, Union[int, np.ndarray]] = {}
for index, spec in enumerate(input_specs):
arg = args[index]
if isinstance(arg, list):
arg = np.array(arg)
is_valid = isinstance(arg, (int, np.integer)) or (
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
)

View File

@@ -178,7 +178,7 @@ def test_client_server_api(helpers):
def function(x):
return x + 42
inputset = range(10)
inputset = [np.random.randint(0, 10, size=(3,)) for _ in range(10)]
circuit = function.compile(inputset, configuration.fork(jit=False))
# for coverage
@@ -206,7 +206,7 @@ def test_client_server_api(helpers):
]
for client in clients:
args = client.encrypt(4)
args = client.encrypt([3, 8, 1])
serialized_args = client.specs.serialize_public_args(args)
serialized_evaluation_keys = client.evaluation_keys.serialize()
@@ -220,7 +220,7 @@ def test_client_server_api(helpers):
unserialized_result = client.specs.unserialize_public_result(serialized_result)
output = client.decrypt(unserialized_result)
assert output == 46
assert np.array_equal(output, [45, 50, 43])
server.cleanup()