feat: make evaluation keys explicit

This commit is contained in:
Umut
2022-05-30 14:46:28 +02:00
parent 0cc32cda1c
commit 51ae3a1867
5 changed files with 36 additions and 37 deletions

View File

@@ -2,7 +2,7 @@
Export everything that users might need.
"""
from concrete.compiler import PublicArguments, PublicResult
from concrete.compiler import EvaluationKeys, PublicArguments, PublicResult
from .compilation import (
Circuit,

View File

@@ -138,7 +138,8 @@ class Circuit:
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `run` method")
return self.server.run(args)
self.keygen(force=False)
return self.server.run(args, self.client.evaluation_keys)
def decrypt(
self,

View File

@@ -9,7 +9,14 @@ from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import ClientSupport, KeySet, KeySetCache, PublicArguments, PublicResult
from concrete.compiler import (
ClientSupport,
EvaluationKeys,
KeySet,
KeySetCache,
PublicArguments,
PublicResult,
)
from ..dtypes.integer import SignedInteger, UnsignedInteger
from ..internal.utils import assert_that
@@ -209,21 +216,17 @@ class Client:
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
def unserialize_and_decrypt(
self,
result: bytes,
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
@property
def evaluation_keys(self) -> EvaluationKeys:
"""
Decrypt serialized result of homomorphic evaluaton.
Args:
result (bytes):
serialized encrypted result of homomorphic evaluaton
Get evaluation keys for encrypted computation.
Returns:
Union[int, numpy.ndarray]:
clear result of homomorphic evaluaton
EvaluationKeys
evaluation keys for encrypted computation
"""
unserialized_result = self.specs.unserialize_public_result(result)
return self.decrypt(unserialized_result)
self.keygen(force=False)
assert self._keyset is not None
return self._keyset.get_evaluation_keys()

View File

@@ -9,6 +9,7 @@ from typing import List, Optional, Union
from concrete.compiler import (
CompilationOptions,
EvaluationKeys,
JITCompilationResult,
JITLambda,
JITSupport,
@@ -158,7 +159,7 @@ class Server:
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
def run(self, args: PublicArguments) -> PublicResult:
def run(self, args: PublicArguments, evaluation_keys: EvaluationKeys) -> PublicResult:
"""
Evaluate using encrypted arguments.
@@ -166,28 +167,15 @@ class Server:
args (PublicArguments):
encrypted arguments of the computation
Returns:
PublicResult:
encrypted result of the computation
"""
return self._support.server_call(self._server_lambda, args)
def unserialize_and_run(self, args: bytes) -> PublicResult:
"""
Evaluate using serialized encrypted arguments.
Args:
args (bytes):
serialized encrypted arguments of the computation
evaluation_keys (EvaluationKeys):
evaluation keys for encrypted computation
Returns:
PublicResult:
encrypted result of the computation
"""
unserialized_args = self.client_specs.unserialize_public_args(args)
return self.run(unserialized_args)
return self._support.server_call(self._server_lambda, args, evaluation_keys)
def cleanup(self):
"""

View File

@@ -8,7 +8,7 @@ from pathlib import Path
import numpy as np
import pytest
from concrete.numpy import Client, ClientSpecs, Server
from concrete.numpy import Client, ClientSpecs, EvaluationKeys, Server
from concrete.numpy.compilation import compiler
@@ -207,12 +207,19 @@ def test_client_server_api(helpers):
for client in clients:
args = client.encrypt(4)
serialized_args = client.specs.serialize_public_args(args)
result = server.unserialize_and_run(serialized_args)
serialized_args = client.specs.serialize_public_args(args)
serialized_evaluation_keys = client.evaluation_keys.serialize()
unserialized_args = server.client_specs.unserialize_public_args(serialized_args)
unserialized_evaluation_keys = EvaluationKeys.unserialize(serialized_evaluation_keys)
result = server.run(unserialized_args, unserialized_evaluation_keys)
serialized_result = server.client_specs.serialize_public_result(result)
output = client.unserialize_and_decrypt(serialized_result)
unserialized_result = client.specs.unserialize_public_result(serialized_result)
output = client.decrypt(unserialized_result)
assert output == 46
server.cleanup()