feat: implement client server architecture

This commit is contained in:
Umut
2022-05-12 09:53:14 +02:00
parent 2c689a3238
commit 90c95e380c
8 changed files with 559 additions and 382 deletions

View File

@@ -4,10 +4,13 @@ Export everything that users might need.
from .compilation import (
Circuit,
Client,
ClientSpecs,
Compiler,
Configuration,
DebugArtifacts,
EncryptionStatus,
Server,
compiler,
)
from .extensions import LookupTable

View File

@@ -4,6 +4,9 @@ Glue the compilation process together.
from .artifacts import DebugArtifacts
from .circuit import Circuit
from .client import Client
from .compiler import Compiler, EncryptionStatus
from .configuration import Configuration
from .decorator import compiler
from .server import Server
from .specs import ClientSpecs

View File

@@ -2,267 +2,60 @@
Declaration of `Circuit` class.
"""
import pickle
import shutil
import tempfile
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union, cast
from typing import Any, Optional, Tuple, Union, cast
import numpy as np
from concrete.compiler import (
ClientParameters,
ClientSupport,
CompilationOptions,
JITCompilationResult,
JITLambda,
JITSupport,
KeySet,
KeySetCache,
LibraryCompilationResult,
LibraryLambda,
LibrarySupport,
PublicArguments,
PublicResult,
)
from concrete.compiler import PublicArguments, PublicResult
from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Graph
from ..values import Value
from .client import Client
from .configuration import Configuration
from .server import Server
class Circuit:
"""
Circuit class, to combine computation graph and compiler engine into a single object.
Circuit class, to combine computation graph, mlir, client and server into a single object.
"""
# pylint: disable=too-many-instance-attributes
configuration: Configuration
graph: Graph
mlir: str
client_parameters: Optional[ClientParameters]
_support: Union[JITSupport, LibrarySupport]
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
_server_lambda: Union[JITLambda, LibraryLambda]
client: Client
server: Server
_output_dir: Optional[tempfile.TemporaryDirectory]
_keyset: KeySet
_keyset_cache: KeySetCache
# pylint: enable=too-many-instance-attributes
def __init__(
self,
configuration: Configuration,
graph: Graph,
mlir: str,
support: Optional[Union[JITSupport, LibrarySupport]] = None,
compilation_result: Optional[Union[JITCompilationResult, LibraryCompilationResult]] = None,
server_lambda: Optional[Union[JITLambda, LibraryLambda]] = None,
output_dir: Optional[tempfile.TemporaryDirectory] = None,
):
self.configuration = configuration
self._output_dir = output_dir
def __init__(self, graph: Graph, mlir: str, configuration: Optional[Configuration] = None):
self.configuration = configuration if configuration is not None else Configuration()
self.graph = graph
self.mlir = mlir
self.client_parameters = None
if configuration.virtual:
assert_that(configuration.enable_unsafe_features)
if self.configuration.virtual:
assert_that(self.configuration.enable_unsafe_features)
return
assert support is not None
assert compilation_result is not None
assert server_lambda is not None
output_signs = []
for i in range(len(graph.output_nodes)): # pylint: disable=consider-using-enumerate
output_value = graph.output_nodes[i].output
assert_that(isinstance(output_value.dtype, Integer))
output_dtype = cast(Integer, output_value.dtype)
output_signs.append(output_dtype.is_signed)
assert_that(
(
isinstance(support, JITSupport)
and isinstance(compilation_result, JITCompilationResult)
and isinstance(server_lambda, JITLambda)
)
or (
isinstance(support, LibrarySupport)
and isinstance(compilation_result, LibraryCompilationResult)
and isinstance(server_lambda, LibraryLambda)
)
)
self.server = Server.create(mlir, output_signs, self.configuration)
self._support = support
self._compilation_result = compilation_result
self._server_lambda = server_lambda
self._output_dir = output_dir
if isinstance(support, LibrarySupport):
assert output_dir is not None
assert_that(support.output_dir_path == str(output_dir.name))
self.client_parameters = support.load_client_parameters(compilation_result)
keyset = None
keyset_cache = None
if configuration.use_insecure_key_cache:
assert_that(configuration.enable_unsafe_features)
if self.configuration.use_insecure_key_cache:
assert_that(self.configuration.enable_unsafe_features)
location = Configuration.insecure_key_cache_location()
if location is not None:
keyset_cache = KeySetCache.new(str(location))
keyset_cache = str(location)
self._keyset = keyset
self._keyset_cache = keyset_cache
@staticmethod
def create(graph: Graph, mlir: str, configuration: Optional[Configuration] = None) -> "Circuit":
"""
Create a circuit from a graph and its MLIR.
Args:
graph (Graph):
graph of the circuit
mlir (str):
mlir of the circuit
configuration (Optional[Configuration], default = None):
configuration to use
Returns:
Circuit:
circuit of graph
"""
configuration = configuration if configuration is not None else Configuration()
if configuration.virtual:
return Circuit(configuration, graph, mlir)
options = CompilationOptions.new("main")
options.set_loop_parallelize(configuration.loop_parallelize)
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
options.set_auto_parallelize(configuration.auto_parallelize)
options.set_p_error(configuration.p_error)
if configuration.jit:
output_dir = None
support = JITSupport.new()
compilation_result = support.compile(mlir, options)
server_lambda = support.load_server_lambda(compilation_result)
else:
# pylint: disable=consider-using-with
output_dir = tempfile.TemporaryDirectory()
output_dir_path = Path(output_dir.name)
# pylint: enable=consider-using-with
support = LibrarySupport.new(
str(output_dir_path), generateCppHeader=False, generateStaticLib=False
)
compilation_result = support.compile(mlir, options)
server_lambda = support.load_server_lambda(compilation_result)
return Circuit(
configuration,
graph,
mlir,
support,
compilation_result,
server_lambda,
output_dir,
)
def save(self, path: Union[str, Path]):
"""
Save the circuit into the given path in zip format.
Args:
path (Union[str, Path]):
path to save the circuit
"""
if not self.configuration.virtual and self.configuration.jit:
raise RuntimeError("JIT Circuits cannot be saved")
if self.configuration.virtual:
# pylint: disable=consider-using-with
self._output_dir = tempfile.TemporaryDirectory()
# pylint: enable=consider-using-with
assert self._output_dir is not None
output_dir_path = Path(self._output_dir.name)
with open(output_dir_path / "out.pickle", "wb") as f:
attributes = {
"configuration": self.configuration,
"graph": self.graph,
"mlir": self.mlir,
}
pickle.dump(attributes, f)
path = str(path)
if path.endswith(".zip"):
path = path[: len(path) - 4]
shutil.make_archive(path, "zip", str(output_dir_path))
if self.configuration.virtual:
self.cleanup()
self._output_dir = None
@staticmethod
def load(path: Union[str, Path]) -> "Circuit":
"""
Load the circuit from the given path in zip format.
Args:
path (Union[str, Path]):
path to load the circuit from
Returns:
Circuit:
circuit loaded from the filesystem
"""
# pylint: disable=consider-using-with
output_dir = tempfile.TemporaryDirectory()
output_dir_path = Path(output_dir.name)
# pylint: enable=consider-using-with
shutil.unpack_archive(path, str(output_dir_path), "zip")
with open(output_dir_path / "out.pickle", "rb") as f:
attributes = pickle.load(f)
configuration = attributes["configuration"]
graph = attributes["graph"]
mlir = attributes["mlir"]
if configuration.virtual:
output_dir.cleanup()
return Circuit(configuration, graph, mlir)
support = LibrarySupport.new(
str(output_dir_path), generateCppHeader=False, generateStaticLib=False
)
compilation_result = support.reload("main")
server_lambda = support.load_server_lambda(compilation_result)
return Circuit(
configuration,
graph,
mlir,
support,
compilation_result,
server_lambda,
output_dir,
)
self.client = Client(self.server.client_specs, keyset_cache)
def __str__(self):
return self.graph.format()
@@ -310,8 +103,7 @@ class Circuit:
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `keygen` method")
if self._keyset is None or force:
self._keyset = ClientSupport.key_set(self.client_parameters, self._keyset_cache)
self.client.keygen(force)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
"""
@@ -329,51 +121,7 @@ class Circuit:
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `encrypt` method")
if len(args) != len(self.graph.input_nodes):
raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}")
sanitized_args = {}
for index, node in self.graph.input_nodes.items():
arg = args[index]
is_valid = isinstance(arg, (int, np.integer)) or (
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
)
expected_value = node.output
assert_that(isinstance(expected_value.dtype, Integer))
expected_dtype = cast(Integer, expected_value.dtype)
if is_valid:
expected_min = expected_dtype.min()
expected_max = expected_dtype.max()
expected_shape = expected_value.shape
actual_min = arg if isinstance(arg, int) else arg.min()
actual_max = arg if isinstance(arg, int) else arg.max()
actual_shape = () if isinstance(arg, int) else arg.shape
is_valid = (
actual_min >= expected_min
and actual_max <= expected_max
and actual_shape == expected_shape
)
if is_valid:
sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8)
if not is_valid:
actual_value = Value.of(arg, is_encrypted=expected_value.is_encrypted)
raise ValueError(
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
)
self.keygen(force=False)
return ClientSupport.encrypt_arguments(
self.client_parameters,
self._keyset,
[sanitized_args[i] for i in range(len(sanitized_args))],
)
return self.client.encrypt(*args)
def run(self, args: PublicArguments) -> PublicResult:
"""
@@ -391,7 +139,7 @@ class Circuit:
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `run` method")
return self._support.server_call(self._server_lambda, args)
return self.server.run(args)
def decrypt(
self,
@@ -412,34 +160,7 @@ class Circuit:
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `decrypt` method")
results = ClientSupport.decrypt_result(self._keyset, result)
if not isinstance(results, tuple):
results = (results,)
sanitized_results: List[Union[int, np.ndarray]] = []
for index, node in self.graph.output_nodes.items():
expected_value = node.output
assert_that(isinstance(expected_value.dtype, Integer))
expected_dtype = cast(Integer, expected_value.dtype)
n = expected_dtype.bit_width
result = results[index] % (2**n)
if expected_dtype.is_signed:
if isinstance(result, int):
sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n)
sanitized_results.append(sanititzed_result)
else:
result = result.astype(np.longlong) # to prevent overflows in numpy
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n))
sanitized_results.append(sanititzed_result.astype(np.int8))
else:
sanitized_results.append(
result if isinstance(result, int) else result.astype(np.uint8)
)
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
return self.client.decrypt(result)
def encrypt_run_decrypt(self, *args: Any) -> Any:
"""
@@ -464,5 +185,4 @@ class Circuit:
Cleanup the temporary library output directory.
"""
if self._output_dir is not None:
self._output_dir.cleanup()
self.server.cleanup()

View File

@@ -0,0 +1,210 @@
"""
Declaration of `Client` class.
"""
import json
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import ClientSupport, KeySet, KeySetCache, PublicArguments, PublicResult
from ..dtypes.integer import SignedInteger, UnsignedInteger
from ..internal.utils import assert_that
from ..values.value import Value
from .specs import ClientSpecs
class Client:
"""
Client class, which can be used to manage keys, encrypt arguments and decrypt results.
"""
specs: ClientSpecs
_keyset: Optional[KeySet]
_keyset_cache: Optional[KeySetCache]
def __init__(
self,
client_specs: ClientSpecs,
keyset_cache_directory: Optional[Union[str, Path]] = None,
):
self.specs = client_specs
self._keyset = None
self._keyset_cache = None
if keyset_cache_directory is not None:
self._keyset_cache = KeySetCache.new(str(keyset_cache_directory))
def save(self, path: Union[str, Path]):
"""
Save the client into the given path in zip format.
Args:
path (Union[str, Path]):
path to save the client
"""
with tempfile.TemporaryDirectory() as tmp_dir:
with open(Path(tmp_dir) / "client.specs.json", "w", encoding="utf-8") as f:
f.write(self.specs.serialize())
path = str(path)
if path.endswith(".zip"):
path = path[: len(path) - 4]
shutil.make_archive(path, "zip", tmp_dir)
@staticmethod
def load(
path: Union[str, Path],
keyset_cache_directory: Optional[Union[str, Path]] = None,
) -> "Client":
"""
Load the client from the given path in zip format.
Args:
path (Union[str, Path]):
path to load the client from
keyset_cache_directory (Optional[Union[str, Path]], default = None):
keyset cache directory to use
Returns:
Client:
client loaded from the filesystem
"""
with tempfile.TemporaryDirectory() as tmp_dir:
shutil.unpack_archive(path, tmp_dir, "zip")
with open(Path(tmp_dir) / "client.specs.json", "r", encoding="utf-8") as f:
client_specs = ClientSpecs.unserialize(f.read())
return Client(client_specs, keyset_cache_directory)
def keygen(self, force: bool = False):
"""
Generate keys required for homomorphic evaluation.
Args:
force (bool, default = False):
whether to generate new keys even if keys are already generated
"""
if self._keyset is None or force:
self._keyset = ClientSupport.key_set(self.specs.client_parameters, self._keyset_cache)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
"""
Prepare inputs to be run on the circuit.
Args:
*args (Union[int, numpy.ndarray]):
inputs to the circuit
Returns:
PublicArguments:
encrypted and plain arguments as well as public keys
"""
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
assert_that("inputs" in client_parameters_json)
input_specs = client_parameters_json["inputs"]
if len(args) != len(input_specs):
raise ValueError(f"Expected {len(input_specs)} inputs but got {len(args)}")
sanitized_args = {}
for index, spec in enumerate(input_specs):
arg = args[index]
is_valid = isinstance(arg, (int, np.integer)) or (
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
)
width = spec["shape"]["width"]
shape = tuple(spec["shape"]["dimensions"])
is_encrypted = spec["encryption"] is not None
expected_dtype = UnsignedInteger(width)
expected_value = Value(expected_dtype, shape, is_encrypted)
if is_valid:
expected_min = expected_dtype.min()
expected_max = expected_dtype.max()
actual_min = arg if isinstance(arg, int) else arg.min()
actual_max = arg if isinstance(arg, int) else arg.max()
actual_shape = () if isinstance(arg, int) else arg.shape
is_valid = (
actual_min >= expected_min
and actual_max <= expected_max
and actual_shape == expected_value.shape
)
if is_valid:
sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8)
if not is_valid:
actual_value = Value.of(arg, is_encrypted=is_encrypted)
raise ValueError(
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
)
self.keygen(force=False)
return ClientSupport.encrypt_arguments(
self.specs.client_parameters,
self._keyset,
[sanitized_args[i] for i in range(len(sanitized_args))],
)
def decrypt(
self,
result: PublicResult,
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
"""
Decrypt result of homomorphic evaluaton.
Args:
result (PublicResult):
encrypted result of homomorphic evaluaton
Returns:
Union[int, numpy.ndarray]:
clear result of homomorphic evaluaton
"""
results = ClientSupport.decrypt_result(self._keyset, result)
if not isinstance(results, tuple):
results = (results,)
sanitized_results: List[Union[int, np.ndarray]] = []
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
assert_that("outputs" in client_parameters_json)
output_specs = client_parameters_json["outputs"]
for index, spec in enumerate(output_specs):
n = spec["shape"]["width"]
expected_dtype = (
SignedInteger(n) if self.specs.output_signs[index] else UnsignedInteger(n)
)
result = results[index] % (2**n)
if expected_dtype.is_signed:
if isinstance(result, int):
sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n)
sanitized_results.append(sanititzed_result)
else:
result = result.astype(np.longlong) # to prevent overflows in numpy
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n))
sanitized_results.append(sanititzed_result.astype(np.int8))
else:
sanitized_results.append(
result if isinstance(result, int) else result.astype(np.uint8)
)
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)

View File

@@ -375,10 +375,12 @@ class Compiler:
print()
circuit = Circuit.create(self.graph, mlir, self.configuration)
circuit = Circuit(self.graph, mlir, self.configuration)
if not self.configuration.virtual:
assert circuit.client_parameters is not None
self.artifacts.add_client_parameters(circuit.client_parameters.serialize())
assert circuit.client.specs.client_parameters is not None
self.artifacts.add_client_parameters(
circuit.client.specs.client_parameters.serialize()
)
return circuit
except Exception: # pragma: no cover

View File

@@ -0,0 +1,182 @@
"""
Declaration of `Server` class.
"""
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Union
from concrete.compiler import (
CompilationOptions,
JITCompilationResult,
JITLambda,
JITSupport,
LibraryCompilationResult,
LibraryLambda,
LibrarySupport,
PublicArguments,
PublicResult,
)
from ..internal.utils import assert_that
from .configuration import Configuration
from .specs import ClientSpecs
class Server:
"""
Client class, which can be used to perform homomorphic computation.
"""
client_specs: ClientSpecs
_output_dir: Optional[tempfile.TemporaryDirectory]
_support: Union[JITSupport, LibrarySupport]
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
_server_lambda: Union[JITLambda, LibraryLambda]
def __init__(
self,
client_specs: ClientSpecs,
output_dir: Optional[tempfile.TemporaryDirectory],
support: Union[JITSupport, LibrarySupport],
compilation_result: Union[JITCompilationResult, LibraryCompilationResult],
server_lambda: Union[JITLambda, LibraryLambda],
):
self.client_specs = client_specs
self._output_dir = output_dir
self._support = support
self._compilation_result = compilation_result
self._server_lambda = server_lambda
assert_that(
support.load_client_parameters(compilation_result).serialize()
== client_specs.client_parameters.serialize()
)
@staticmethod
def create(mlir: str, output_signs: List[bool], configuration: Configuration) -> "Server":
"""
Create a server using MLIR and output sign information.
Args:
mlir (str):
mlir to compile
output_signs (List[bool]):
sign status of the outputs
configuration (Optional[Configuration], default = None):
configuration to use
"""
options = CompilationOptions.new("main")
options.set_loop_parallelize(configuration.loop_parallelize)
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
options.set_auto_parallelize(configuration.auto_parallelize)
options.set_p_error(configuration.p_error)
if configuration.jit:
output_dir = None
support = JITSupport.new()
compilation_result = support.compile(mlir, options)
server_lambda = support.load_server_lambda(compilation_result)
else:
# pylint: disable=consider-using-with
output_dir = tempfile.TemporaryDirectory()
output_dir_path = Path(output_dir.name)
# pylint: enable=consider-using-with
support = LibrarySupport.new(
str(output_dir_path), generateCppHeader=False, generateStaticLib=False
)
compilation_result = support.compile(mlir, options)
server_lambda = support.load_server_lambda(compilation_result)
client_specs = ClientSpecs(support.load_client_parameters(compilation_result), output_signs)
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
def save(self, path: Union[str, Path]):
"""
Save the server into the given path in zip format.
Args:
path (Union[str, Path]):
path to save the server
"""
if self._output_dir is None:
raise RuntimeError("Just-in-Time compilation cannot be saved")
with open(Path(self._output_dir.name) / "client.specs.json", "w", encoding="utf-8") as f:
f.write(self.client_specs.serialize())
path = str(path)
if path.endswith(".zip"):
path = path[: len(path) - 4]
shutil.make_archive(path, "zip", self._output_dir.name)
@staticmethod
def load(path: Union[str, Path]) -> "Server":
"""
Load the server from the given path in zip format.
Args:
path (Union[str, Path]):
path to load the server from
Returns:
Server:
server loaded from the filesystem
"""
# pylint: disable=consider-using-with
output_dir = tempfile.TemporaryDirectory()
output_dir_path = Path(output_dir.name)
# pylint: enable=consider-using-with
shutil.unpack_archive(path, str(output_dir_path), "zip")
with open(output_dir_path / "client.specs.json", "r", encoding="utf-8") as f:
client_specs = ClientSpecs.unserialize(f.read())
support = LibrarySupport.new(
str(output_dir_path),
generateCppHeader=False,
generateStaticLib=False,
)
compilation_result = support.reload("main")
server_lambda = support.load_server_lambda(compilation_result)
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
def run(self, args: PublicArguments) -> PublicResult:
"""
Evaluate using encrypted arguments.
Args:
args (PublicArguments):
encrypted arguments of the computation
Returns:
PublicResult:
encrypted result of the computation
"""
return self._support.server_call(self._server_lambda, args)
def cleanup(self):
"""
Cleanup the temporary library output directory.
"""
if self._output_dir is not None:
self._output_dir.cleanup()

View File

@@ -0,0 +1,59 @@
"""
Declaration of `ClientSpecs` class.
"""
import json
from typing import List
from concrete.compiler import ClientParameters
class ClientSpecs:
"""
ClientSpecs class, to create Client objects.
"""
client_parameters: ClientParameters
output_signs: List[bool]
def __init__(self, client_parameters: ClientParameters, output_signs: List[bool]):
self.client_parameters = client_parameters
self.output_signs = output_signs
def serialize(self) -> str:
"""
Serialize client specs into a string representation.
Returns:
str:
string representation of the client specs
"""
client_parameters_json = json.loads(self.client_parameters.serialize())
return json.dumps(
{
"client_parameters": client_parameters_json,
"output_signs": self.output_signs,
}
)
@staticmethod
def unserialize(serialized_client_specs: str) -> "ClientSpecs":
"""
Create client specs from its string representation.
Args:
serialized_client_specs (str):
client specs to unserialize
Returns:
ClientSpecs:
unserialized client specs
"""
raw_specs = json.loads(serialized_client_specs)
client_parameters_bytes = json.dumps(raw_specs["client_parameters"]).encode("utf-8")
client_parameters = ClientParameters.unserialize(client_parameters_bytes)
return ClientSpecs(client_parameters, raw_specs["output_signs"])

View File

@@ -7,8 +7,9 @@ from pathlib import Path
import numpy as np
import pytest
from concrete.compiler import PublicArguments, PublicResult
from concrete.numpy import Circuit
from concrete.numpy import Client, ClientSpecs, Configuration, Server
from concrete.numpy.compilation import compiler
@@ -98,7 +99,7 @@ def test_circuit_bad_run(helpers):
circuit.encrypt_run_decrypt(-1, 11)
assert str(excinfo.value) == (
"Expected argument 0 to be EncryptedScalar<uint4> but it's EncryptedScalar<int1>"
"Expected argument 0 to be EncryptedScalar<uint6> but it's EncryptedScalar<int1>"
)
# with negative argument 1
@@ -108,7 +109,7 @@ def test_circuit_bad_run(helpers):
circuit.encrypt_run_decrypt(1, -11)
assert str(excinfo.value) == (
"Expected argument 1 to be EncryptedScalar<uint5> but it's EncryptedScalar<int5>"
"Expected argument 1 to be EncryptedScalar<uint6> but it's EncryptedScalar<int5>"
)
# with large argument 0
@@ -118,7 +119,7 @@ def test_circuit_bad_run(helpers):
circuit.encrypt_run_decrypt(100, 10)
assert str(excinfo.value) == (
"Expected argument 0 to be EncryptedScalar<uint4> but it's EncryptedScalar<uint7>"
"Expected argument 0 to be EncryptedScalar<uint6> but it's EncryptedScalar<uint7>"
)
# with large argument 1
@@ -128,7 +129,7 @@ def test_circuit_bad_run(helpers):
circuit.encrypt_run_decrypt(1, 100)
assert str(excinfo.value) == (
"Expected argument 1 to be EncryptedScalar<uint5> but it's EncryptedScalar<uint7>"
"Expected argument 1 to be EncryptedScalar<uint6> but it's EncryptedScalar<uint7>"
)
@@ -167,9 +168,69 @@ def test_circuit_virtual_explicit_api(helpers):
assert str(excinfo.value) == "Virtual circuits cannot use `decrypt` method"
def test_circuit_bad_save(helpers):
def test_client_server_api(helpers):
"""
Test `save` method of `Circuit` class with bad parameters.
Test client/server API.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
circuit = function.compile(inputset, configuration.fork(jit=False))
# for coverage
circuit.keygen()
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir_path = Path(tmp_dir)
server_path = tmp_dir_path / "server.zip"
circuit.server.save(server_path)
client_path = tmp_dir_path / "client.zip"
circuit.client.save(client_path)
circuit.cleanup()
server = Server.load(server_path)
serialized_client_specs = server.client_specs.serialize()
client_specs = ClientSpecs.unserialize(serialized_client_specs)
clients = [
Client(client_specs, Configuration.insecure_key_cache_location()),
Client.load(client_path, Configuration.insecure_key_cache_location()),
]
for client in clients:
raw_input = client.encrypt(4)
serialized_input = raw_input.serialize()
unserialized_input = PublicArguments.unserialize(
server.client_specs.client_parameters,
serialized_input,
)
evaluation = server.run(unserialized_input)
serialized_evaluation = evaluation.serialize()
unserialized_evaluation = PublicResult.unserialize(
client.specs.client_parameters,
serialized_evaluation,
)
output = client.decrypt(unserialized_evaluation)
assert output == 46
server.cleanup()
def test_bad_server_save(helpers):
"""
Test `save` method of `Server` class with bad parameters.
"""
configuration = helpers.configuration()
@@ -182,69 +243,6 @@ def test_circuit_bad_save(helpers):
circuit = function.compile(inputset, configuration)
with pytest.raises(RuntimeError) as excinfo:
circuit.save("circuit.zip")
circuit.server.save("test.zip")
assert str(excinfo.value) == "JIT Circuits cannot be saved"
@pytest.mark.parametrize(
"virtual",
[False, True],
)
def test_circuit_save_load(virtual, helpers):
"""
Test `save`, `load`, and `cleanup` methods of `Circuit` class.
"""
configuration = helpers.configuration().fork(jit=False, virtual=virtual)
def save(base):
@compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
circuit = function.compile(inputset, configuration)
circuit.save(base / "circuit.zip")
circuit.cleanup()
def load(base):
circuit = Circuit.load(base / "circuit.zip")
helpers.check_str(
"""
%0 = x # EncryptedScalar<uint4>
%1 = 42 # ClearScalar<uint6>
%2 = add(%0, %1) # EncryptedScalar<uint6>
return %2
""",
str(circuit),
)
if virtual:
helpers.check_str("Virtual circuits doesn't have MLIR.", circuit.mlir)
else:
helpers.check_str(
"""
module {
func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c42_i7 = arith.constant 42 : i7
%0 = "FHE.add_eint_int"(%arg0, %c42_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
return %0 : !FHE.eint<6>
}
}
""",
circuit.mlir,
)
helpers.check_execution(circuit, lambda x: x + 42, 4)
circuit.cleanup()
with tempfile.TemporaryDirectory() as tmp:
path = Path(tmp)
save(path)
load(path)
assert str(excinfo.value) == "Just-in-Time compilation cannot be saved"