mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: make evaluation keys explicit
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
Export everything that users might need.
|
||||
"""
|
||||
|
||||
from concrete.compiler import PublicArguments, PublicResult
|
||||
from concrete.compiler import EvaluationKeys, PublicArguments, PublicResult
|
||||
|
||||
from .compilation import (
|
||||
Circuit,
|
||||
|
||||
@@ -138,7 +138,8 @@ class Circuit:
|
||||
if self.configuration.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `run` method")
|
||||
|
||||
return self.server.run(args)
|
||||
self.keygen(force=False)
|
||||
return self.server.run(args, self.client.evaluation_keys)
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
|
||||
@@ -9,7 +9,14 @@ from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import ClientSupport, KeySet, KeySetCache, PublicArguments, PublicResult
|
||||
from concrete.compiler import (
|
||||
ClientSupport,
|
||||
EvaluationKeys,
|
||||
KeySet,
|
||||
KeySetCache,
|
||||
PublicArguments,
|
||||
PublicResult,
|
||||
)
|
||||
|
||||
from ..dtypes.integer import SignedInteger, UnsignedInteger
|
||||
from ..internal.utils import assert_that
|
||||
@@ -209,21 +216,17 @@ class Client:
|
||||
|
||||
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
|
||||
|
||||
def unserialize_and_decrypt(
|
||||
self,
|
||||
result: bytes,
|
||||
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
@property
|
||||
def evaluation_keys(self) -> EvaluationKeys:
|
||||
"""
|
||||
Decrypt serialized result of homomorphic evaluaton.
|
||||
|
||||
Args:
|
||||
result (bytes):
|
||||
serialized encrypted result of homomorphic evaluaton
|
||||
Get evaluation keys for encrypted computation.
|
||||
|
||||
Returns:
|
||||
Union[int, numpy.ndarray]:
|
||||
clear result of homomorphic evaluaton
|
||||
EvaluationKeys
|
||||
evaluation keys for encrypted computation
|
||||
"""
|
||||
|
||||
unserialized_result = self.specs.unserialize_public_result(result)
|
||||
return self.decrypt(unserialized_result)
|
||||
self.keygen(force=False)
|
||||
|
||||
assert self._keyset is not None
|
||||
return self._keyset.get_evaluation_keys()
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import List, Optional, Union
|
||||
|
||||
from concrete.compiler import (
|
||||
CompilationOptions,
|
||||
EvaluationKeys,
|
||||
JITCompilationResult,
|
||||
JITLambda,
|
||||
JITSupport,
|
||||
@@ -158,7 +159,7 @@ class Server:
|
||||
|
||||
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
|
||||
|
||||
def run(self, args: PublicArguments) -> PublicResult:
|
||||
def run(self, args: PublicArguments, evaluation_keys: EvaluationKeys) -> PublicResult:
|
||||
"""
|
||||
Evaluate using encrypted arguments.
|
||||
|
||||
@@ -166,28 +167,15 @@ class Server:
|
||||
args (PublicArguments):
|
||||
encrypted arguments of the computation
|
||||
|
||||
Returns:
|
||||
PublicResult:
|
||||
encrypted result of the computation
|
||||
"""
|
||||
|
||||
return self._support.server_call(self._server_lambda, args)
|
||||
|
||||
def unserialize_and_run(self, args: bytes) -> PublicResult:
|
||||
"""
|
||||
Evaluate using serialized encrypted arguments.
|
||||
|
||||
Args:
|
||||
args (bytes):
|
||||
serialized encrypted arguments of the computation
|
||||
evaluation_keys (EvaluationKeys):
|
||||
evaluation keys for encrypted computation
|
||||
|
||||
Returns:
|
||||
PublicResult:
|
||||
encrypted result of the computation
|
||||
"""
|
||||
|
||||
unserialized_args = self.client_specs.unserialize_public_args(args)
|
||||
return self.run(unserialized_args)
|
||||
return self._support.server_call(self._server_lambda, args, evaluation_keys)
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.numpy import Client, ClientSpecs, Server
|
||||
from concrete.numpy import Client, ClientSpecs, EvaluationKeys, Server
|
||||
from concrete.numpy.compilation import compiler
|
||||
|
||||
|
||||
@@ -207,12 +207,19 @@ def test_client_server_api(helpers):
|
||||
|
||||
for client in clients:
|
||||
args = client.encrypt(4)
|
||||
serialized_args = client.specs.serialize_public_args(args)
|
||||
|
||||
result = server.unserialize_and_run(serialized_args)
|
||||
serialized_args = client.specs.serialize_public_args(args)
|
||||
serialized_evaluation_keys = client.evaluation_keys.serialize()
|
||||
|
||||
unserialized_args = server.client_specs.unserialize_public_args(serialized_args)
|
||||
unserialized_evaluation_keys = EvaluationKeys.unserialize(serialized_evaluation_keys)
|
||||
|
||||
result = server.run(unserialized_args, unserialized_evaluation_keys)
|
||||
serialized_result = server.client_specs.serialize_public_result(result)
|
||||
|
||||
output = client.unserialize_and_decrypt(serialized_result)
|
||||
unserialized_result = client.specs.unserialize_public_result(serialized_result)
|
||||
output = client.decrypt(unserialized_result)
|
||||
|
||||
assert output == 46
|
||||
|
||||
server.cleanup()
|
||||
|
||||
Reference in New Issue
Block a user