From 51ae3a18676f7d8b0f15fa4937cfca7edd3a28b5 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 30 May 2022 14:46:28 +0200 Subject: [PATCH] feat: make evaluation keys explicit --- concrete/numpy/__init__.py | 2 +- concrete/numpy/compilation/circuit.py | 3 ++- concrete/numpy/compilation/client.py | 31 +++++++++++++++------------ concrete/numpy/compilation/server.py | 22 +++++-------------- tests/compilation/test_circuit.py | 15 +++++++++---- 5 files changed, 36 insertions(+), 37 deletions(-) diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 2dd37e127..a22797266 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -2,7 +2,7 @@ Export everything that users might need. """ -from concrete.compiler import PublicArguments, PublicResult +from concrete.compiler import EvaluationKeys, PublicArguments, PublicResult from .compilation import ( Circuit, diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index 83e9957a0..3cdd68317 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -138,7 +138,8 @@ class Circuit: if self.configuration.virtual: raise RuntimeError("Virtual circuits cannot use `run` method") - return self.server.run(args) + self.keygen(force=False) + return self.server.run(args, self.client.evaluation_keys) def decrypt( self, diff --git a/concrete/numpy/compilation/client.py b/concrete/numpy/compilation/client.py index 8d2505076..5876f87d2 100644 --- a/concrete/numpy/compilation/client.py +++ b/concrete/numpy/compilation/client.py @@ -9,7 +9,14 @@ from pathlib import Path from typing import List, Optional, Tuple, Union import numpy as np -from concrete.compiler import ClientSupport, KeySet, KeySetCache, PublicArguments, PublicResult +from concrete.compiler import ( + ClientSupport, + EvaluationKeys, + KeySet, + KeySetCache, + PublicArguments, + PublicResult, +) from ..dtypes.integer import SignedInteger, UnsignedInteger from ..internal.utils import assert_that @@ -209,21 +216,17 @@ class Client: return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results) - def unserialize_and_decrypt( - self, - result: bytes, - ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: + @property + def evaluation_keys(self) -> EvaluationKeys: """ - Decrypt serialized result of homomorphic evaluaton. - - Args: - result (bytes): - serialized encrypted result of homomorphic evaluaton + Get evaluation keys for encrypted computation. Returns: - Union[int, numpy.ndarray]: - clear result of homomorphic evaluaton + EvaluationKeys + evaluation keys for encrypted computation """ - unserialized_result = self.specs.unserialize_public_result(result) - return self.decrypt(unserialized_result) + self.keygen(force=False) + + assert self._keyset is not None + return self._keyset.get_evaluation_keys() diff --git a/concrete/numpy/compilation/server.py b/concrete/numpy/compilation/server.py index e69b52c32..740b51128 100644 --- a/concrete/numpy/compilation/server.py +++ b/concrete/numpy/compilation/server.py @@ -9,6 +9,7 @@ from typing import List, Optional, Union from concrete.compiler import ( CompilationOptions, + EvaluationKeys, JITCompilationResult, JITLambda, JITSupport, @@ -158,7 +159,7 @@ class Server: return Server(client_specs, output_dir, support, compilation_result, server_lambda) - def run(self, args: PublicArguments) -> PublicResult: + def run(self, args: PublicArguments, evaluation_keys: EvaluationKeys) -> PublicResult: """ Evaluate using encrypted arguments. @@ -166,28 +167,15 @@ class Server: args (PublicArguments): encrypted arguments of the computation - Returns: - PublicResult: - encrypted result of the computation - """ - - return self._support.server_call(self._server_lambda, args) - - def unserialize_and_run(self, args: bytes) -> PublicResult: - """ - Evaluate using serialized encrypted arguments. - - Args: - args (bytes): - serialized encrypted arguments of the computation + evaluation_keys (EvaluationKeys): + evaluation keys for encrypted computation Returns: PublicResult: encrypted result of the computation """ - unserialized_args = self.client_specs.unserialize_public_args(args) - return self.run(unserialized_args) + return self._support.server_call(self._server_lambda, args, evaluation_keys) def cleanup(self): """ diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 0aa7af232..0f5f48fe4 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -8,7 +8,7 @@ from pathlib import Path import numpy as np import pytest -from concrete.numpy import Client, ClientSpecs, Server +from concrete.numpy import Client, ClientSpecs, EvaluationKeys, Server from concrete.numpy.compilation import compiler @@ -207,12 +207,19 @@ def test_client_server_api(helpers): for client in clients: args = client.encrypt(4) - serialized_args = client.specs.serialize_public_args(args) - result = server.unserialize_and_run(serialized_args) + serialized_args = client.specs.serialize_public_args(args) + serialized_evaluation_keys = client.evaluation_keys.serialize() + + unserialized_args = server.client_specs.unserialize_public_args(serialized_args) + unserialized_evaluation_keys = EvaluationKeys.unserialize(serialized_evaluation_keys) + + result = server.run(unserialized_args, unserialized_evaluation_keys) serialized_result = server.client_specs.serialize_public_result(result) - output = client.unserialize_and_decrypt(serialized_result) + unserialized_result = client.specs.unserialize_public_result(serialized_result) + output = client.decrypt(unserialized_result) + assert output == 46 server.cleanup()