mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: implement client server architecture
This commit is contained in:
@@ -4,10 +4,13 @@ Export everything that users might need.
|
||||
|
||||
from .compilation import (
|
||||
Circuit,
|
||||
Client,
|
||||
ClientSpecs,
|
||||
Compiler,
|
||||
Configuration,
|
||||
DebugArtifacts,
|
||||
EncryptionStatus,
|
||||
Server,
|
||||
compiler,
|
||||
)
|
||||
from .extensions import LookupTable
|
||||
|
||||
@@ -4,6 +4,9 @@ Glue the compilation process together.
|
||||
|
||||
from .artifacts import DebugArtifacts
|
||||
from .circuit import Circuit
|
||||
from .client import Client
|
||||
from .compiler import Compiler, EncryptionStatus
|
||||
from .configuration import Configuration
|
||||
from .decorator import compiler
|
||||
from .server import Server
|
||||
from .specs import ClientSpecs
|
||||
|
||||
@@ -2,267 +2,60 @@
|
||||
Declaration of `Circuit` class.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Tuple, Union, cast
|
||||
from typing import Any, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import (
|
||||
ClientParameters,
|
||||
ClientSupport,
|
||||
CompilationOptions,
|
||||
JITCompilationResult,
|
||||
JITLambda,
|
||||
JITSupport,
|
||||
KeySet,
|
||||
KeySetCache,
|
||||
LibraryCompilationResult,
|
||||
LibraryLambda,
|
||||
LibrarySupport,
|
||||
PublicArguments,
|
||||
PublicResult,
|
||||
)
|
||||
from concrete.compiler import PublicArguments, PublicResult
|
||||
|
||||
from ..dtypes import Integer
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph
|
||||
from ..values import Value
|
||||
from .client import Client
|
||||
from .configuration import Configuration
|
||||
from .server import Server
|
||||
|
||||
|
||||
class Circuit:
|
||||
"""
|
||||
Circuit class, to combine computation graph and compiler engine into a single object.
|
||||
Circuit class, to combine computation graph, mlir, client and server into a single object.
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
|
||||
configuration: Configuration
|
||||
|
||||
graph: Graph
|
||||
mlir: str
|
||||
client_parameters: Optional[ClientParameters]
|
||||
|
||||
_support: Union[JITSupport, LibrarySupport]
|
||||
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
|
||||
_server_lambda: Union[JITLambda, LibraryLambda]
|
||||
client: Client
|
||||
server: Server
|
||||
|
||||
_output_dir: Optional[tempfile.TemporaryDirectory]
|
||||
|
||||
_keyset: KeySet
|
||||
_keyset_cache: KeySetCache
|
||||
|
||||
# pylint: enable=too-many-instance-attributes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
configuration: Configuration,
|
||||
graph: Graph,
|
||||
mlir: str,
|
||||
support: Optional[Union[JITSupport, LibrarySupport]] = None,
|
||||
compilation_result: Optional[Union[JITCompilationResult, LibraryCompilationResult]] = None,
|
||||
server_lambda: Optional[Union[JITLambda, LibraryLambda]] = None,
|
||||
output_dir: Optional[tempfile.TemporaryDirectory] = None,
|
||||
):
|
||||
self.configuration = configuration
|
||||
self._output_dir = output_dir
|
||||
def __init__(self, graph: Graph, mlir: str, configuration: Optional[Configuration] = None):
|
||||
self.configuration = configuration if configuration is not None else Configuration()
|
||||
|
||||
self.graph = graph
|
||||
self.mlir = mlir
|
||||
self.client_parameters = None
|
||||
|
||||
if configuration.virtual:
|
||||
assert_that(configuration.enable_unsafe_features)
|
||||
if self.configuration.virtual:
|
||||
assert_that(self.configuration.enable_unsafe_features)
|
||||
return
|
||||
|
||||
assert support is not None
|
||||
assert compilation_result is not None
|
||||
assert server_lambda is not None
|
||||
output_signs = []
|
||||
for i in range(len(graph.output_nodes)): # pylint: disable=consider-using-enumerate
|
||||
output_value = graph.output_nodes[i].output
|
||||
assert_that(isinstance(output_value.dtype, Integer))
|
||||
output_dtype = cast(Integer, output_value.dtype)
|
||||
output_signs.append(output_dtype.is_signed)
|
||||
|
||||
assert_that(
|
||||
(
|
||||
isinstance(support, JITSupport)
|
||||
and isinstance(compilation_result, JITCompilationResult)
|
||||
and isinstance(server_lambda, JITLambda)
|
||||
)
|
||||
or (
|
||||
isinstance(support, LibrarySupport)
|
||||
and isinstance(compilation_result, LibraryCompilationResult)
|
||||
and isinstance(server_lambda, LibraryLambda)
|
||||
)
|
||||
)
|
||||
self.server = Server.create(mlir, output_signs, self.configuration)
|
||||
|
||||
self._support = support
|
||||
self._compilation_result = compilation_result
|
||||
self._server_lambda = server_lambda
|
||||
|
||||
self._output_dir = output_dir
|
||||
if isinstance(support, LibrarySupport):
|
||||
assert output_dir is not None
|
||||
assert_that(support.output_dir_path == str(output_dir.name))
|
||||
|
||||
self.client_parameters = support.load_client_parameters(compilation_result)
|
||||
keyset = None
|
||||
keyset_cache = None
|
||||
|
||||
if configuration.use_insecure_key_cache:
|
||||
assert_that(configuration.enable_unsafe_features)
|
||||
if self.configuration.use_insecure_key_cache:
|
||||
assert_that(self.configuration.enable_unsafe_features)
|
||||
location = Configuration.insecure_key_cache_location()
|
||||
if location is not None:
|
||||
keyset_cache = KeySetCache.new(str(location))
|
||||
keyset_cache = str(location)
|
||||
|
||||
self._keyset = keyset
|
||||
self._keyset_cache = keyset_cache
|
||||
|
||||
@staticmethod
|
||||
def create(graph: Graph, mlir: str, configuration: Optional[Configuration] = None) -> "Circuit":
|
||||
"""
|
||||
Create a circuit from a graph and its MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph of the circuit
|
||||
|
||||
mlir (str):
|
||||
mlir of the circuit
|
||||
|
||||
configuration (Optional[Configuration], default = None):
|
||||
configuration to use
|
||||
|
||||
Returns:
|
||||
Circuit:
|
||||
circuit of graph
|
||||
"""
|
||||
|
||||
configuration = configuration if configuration is not None else Configuration()
|
||||
if configuration.virtual:
|
||||
return Circuit(configuration, graph, mlir)
|
||||
|
||||
options = CompilationOptions.new("main")
|
||||
|
||||
options.set_loop_parallelize(configuration.loop_parallelize)
|
||||
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
|
||||
options.set_auto_parallelize(configuration.auto_parallelize)
|
||||
options.set_p_error(configuration.p_error)
|
||||
|
||||
if configuration.jit:
|
||||
|
||||
output_dir = None
|
||||
|
||||
support = JITSupport.new()
|
||||
compilation_result = support.compile(mlir, options)
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
else:
|
||||
|
||||
# pylint: disable=consider-using-with
|
||||
output_dir = tempfile.TemporaryDirectory()
|
||||
output_dir_path = Path(output_dir.name)
|
||||
# pylint: enable=consider-using-with
|
||||
|
||||
support = LibrarySupport.new(
|
||||
str(output_dir_path), generateCppHeader=False, generateStaticLib=False
|
||||
)
|
||||
compilation_result = support.compile(mlir, options)
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
return Circuit(
|
||||
configuration,
|
||||
graph,
|
||||
mlir,
|
||||
support,
|
||||
compilation_result,
|
||||
server_lambda,
|
||||
output_dir,
|
||||
)
|
||||
|
||||
def save(self, path: Union[str, Path]):
|
||||
"""
|
||||
Save the circuit into the given path in zip format.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]):
|
||||
path to save the circuit
|
||||
"""
|
||||
|
||||
if not self.configuration.virtual and self.configuration.jit:
|
||||
raise RuntimeError("JIT Circuits cannot be saved")
|
||||
|
||||
if self.configuration.virtual:
|
||||
# pylint: disable=consider-using-with
|
||||
self._output_dir = tempfile.TemporaryDirectory()
|
||||
# pylint: enable=consider-using-with
|
||||
|
||||
assert self._output_dir is not None
|
||||
output_dir_path = Path(self._output_dir.name)
|
||||
|
||||
with open(output_dir_path / "out.pickle", "wb") as f:
|
||||
attributes = {
|
||||
"configuration": self.configuration,
|
||||
"graph": self.graph,
|
||||
"mlir": self.mlir,
|
||||
}
|
||||
pickle.dump(attributes, f)
|
||||
|
||||
path = str(path)
|
||||
if path.endswith(".zip"):
|
||||
path = path[: len(path) - 4]
|
||||
|
||||
shutil.make_archive(path, "zip", str(output_dir_path))
|
||||
|
||||
if self.configuration.virtual:
|
||||
self.cleanup()
|
||||
self._output_dir = None
|
||||
|
||||
@staticmethod
|
||||
def load(path: Union[str, Path]) -> "Circuit":
|
||||
"""
|
||||
Load the circuit from the given path in zip format.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]):
|
||||
path to load the circuit from
|
||||
|
||||
Returns:
|
||||
Circuit:
|
||||
circuit loaded from the filesystem
|
||||
"""
|
||||
|
||||
# pylint: disable=consider-using-with
|
||||
output_dir = tempfile.TemporaryDirectory()
|
||||
output_dir_path = Path(output_dir.name)
|
||||
# pylint: enable=consider-using-with
|
||||
|
||||
shutil.unpack_archive(path, str(output_dir_path), "zip")
|
||||
|
||||
with open(output_dir_path / "out.pickle", "rb") as f:
|
||||
attributes = pickle.load(f)
|
||||
|
||||
configuration = attributes["configuration"]
|
||||
graph = attributes["graph"]
|
||||
mlir = attributes["mlir"]
|
||||
|
||||
if configuration.virtual:
|
||||
output_dir.cleanup()
|
||||
return Circuit(configuration, graph, mlir)
|
||||
|
||||
support = LibrarySupport.new(
|
||||
str(output_dir_path), generateCppHeader=False, generateStaticLib=False
|
||||
)
|
||||
compilation_result = support.reload("main")
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
return Circuit(
|
||||
configuration,
|
||||
graph,
|
||||
mlir,
|
||||
support,
|
||||
compilation_result,
|
||||
server_lambda,
|
||||
output_dir,
|
||||
)
|
||||
self.client = Client(self.server.client_specs, keyset_cache)
|
||||
|
||||
def __str__(self):
|
||||
return self.graph.format()
|
||||
@@ -310,8 +103,7 @@ class Circuit:
|
||||
if self.configuration.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `keygen` method")
|
||||
|
||||
if self._keyset is None or force:
|
||||
self._keyset = ClientSupport.key_set(self.client_parameters, self._keyset_cache)
|
||||
self.client.keygen(force)
|
||||
|
||||
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
|
||||
"""
|
||||
@@ -329,51 +121,7 @@ class Circuit:
|
||||
if self.configuration.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `encrypt` method")
|
||||
|
||||
if len(args) != len(self.graph.input_nodes):
|
||||
raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}")
|
||||
|
||||
sanitized_args = {}
|
||||
for index, node in self.graph.input_nodes.items():
|
||||
arg = args[index]
|
||||
is_valid = isinstance(arg, (int, np.integer)) or (
|
||||
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
|
||||
)
|
||||
|
||||
expected_value = node.output
|
||||
|
||||
assert_that(isinstance(expected_value.dtype, Integer))
|
||||
expected_dtype = cast(Integer, expected_value.dtype)
|
||||
|
||||
if is_valid:
|
||||
expected_min = expected_dtype.min()
|
||||
expected_max = expected_dtype.max()
|
||||
expected_shape = expected_value.shape
|
||||
|
||||
actual_min = arg if isinstance(arg, int) else arg.min()
|
||||
actual_max = arg if isinstance(arg, int) else arg.max()
|
||||
actual_shape = () if isinstance(arg, int) else arg.shape
|
||||
|
||||
is_valid = (
|
||||
actual_min >= expected_min
|
||||
and actual_max <= expected_max
|
||||
and actual_shape == expected_shape
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8)
|
||||
|
||||
if not is_valid:
|
||||
actual_value = Value.of(arg, is_encrypted=expected_value.is_encrypted)
|
||||
raise ValueError(
|
||||
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
|
||||
)
|
||||
|
||||
self.keygen(force=False)
|
||||
return ClientSupport.encrypt_arguments(
|
||||
self.client_parameters,
|
||||
self._keyset,
|
||||
[sanitized_args[i] for i in range(len(sanitized_args))],
|
||||
)
|
||||
return self.client.encrypt(*args)
|
||||
|
||||
def run(self, args: PublicArguments) -> PublicResult:
|
||||
"""
|
||||
@@ -391,7 +139,7 @@ class Circuit:
|
||||
if self.configuration.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `run` method")
|
||||
|
||||
return self._support.server_call(self._server_lambda, args)
|
||||
return self.server.run(args)
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
@@ -412,34 +160,7 @@ class Circuit:
|
||||
if self.configuration.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `decrypt` method")
|
||||
|
||||
results = ClientSupport.decrypt_result(self._keyset, result)
|
||||
if not isinstance(results, tuple):
|
||||
results = (results,)
|
||||
|
||||
sanitized_results: List[Union[int, np.ndarray]] = []
|
||||
|
||||
for index, node in self.graph.output_nodes.items():
|
||||
expected_value = node.output
|
||||
assert_that(isinstance(expected_value.dtype, Integer))
|
||||
|
||||
expected_dtype = cast(Integer, expected_value.dtype)
|
||||
n = expected_dtype.bit_width
|
||||
|
||||
result = results[index] % (2**n)
|
||||
if expected_dtype.is_signed:
|
||||
if isinstance(result, int):
|
||||
sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n)
|
||||
sanitized_results.append(sanititzed_result)
|
||||
else:
|
||||
result = result.astype(np.longlong) # to prevent overflows in numpy
|
||||
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n))
|
||||
sanitized_results.append(sanititzed_result.astype(np.int8))
|
||||
else:
|
||||
sanitized_results.append(
|
||||
result if isinstance(result, int) else result.astype(np.uint8)
|
||||
)
|
||||
|
||||
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
|
||||
return self.client.decrypt(result)
|
||||
|
||||
def encrypt_run_decrypt(self, *args: Any) -> Any:
|
||||
"""
|
||||
@@ -464,5 +185,4 @@ class Circuit:
|
||||
Cleanup the temporary library output directory.
|
||||
"""
|
||||
|
||||
if self._output_dir is not None:
|
||||
self._output_dir.cleanup()
|
||||
self.server.cleanup()
|
||||
|
||||
210
concrete/numpy/compilation/client.py
Normal file
210
concrete/numpy/compilation/client.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Declaration of `Client` class.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import ClientSupport, KeySet, KeySetCache, PublicArguments, PublicResult
|
||||
|
||||
from ..dtypes.integer import SignedInteger, UnsignedInteger
|
||||
from ..internal.utils import assert_that
|
||||
from ..values.value import Value
|
||||
from .specs import ClientSpecs
|
||||
|
||||
|
||||
class Client:
|
||||
"""
|
||||
Client class, which can be used to manage keys, encrypt arguments and decrypt results.
|
||||
"""
|
||||
|
||||
specs: ClientSpecs
|
||||
|
||||
_keyset: Optional[KeySet]
|
||||
_keyset_cache: Optional[KeySetCache]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_specs: ClientSpecs,
|
||||
keyset_cache_directory: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
self.specs = client_specs
|
||||
|
||||
self._keyset = None
|
||||
self._keyset_cache = None
|
||||
|
||||
if keyset_cache_directory is not None:
|
||||
self._keyset_cache = KeySetCache.new(str(keyset_cache_directory))
|
||||
|
||||
def save(self, path: Union[str, Path]):
|
||||
"""
|
||||
Save the client into the given path in zip format.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]):
|
||||
path to save the client
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with open(Path(tmp_dir) / "client.specs.json", "w", encoding="utf-8") as f:
|
||||
f.write(self.specs.serialize())
|
||||
|
||||
path = str(path)
|
||||
if path.endswith(".zip"):
|
||||
path = path[: len(path) - 4]
|
||||
|
||||
shutil.make_archive(path, "zip", tmp_dir)
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: Union[str, Path],
|
||||
keyset_cache_directory: Optional[Union[str, Path]] = None,
|
||||
) -> "Client":
|
||||
"""
|
||||
Load the client from the given path in zip format.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]):
|
||||
path to load the client from
|
||||
|
||||
keyset_cache_directory (Optional[Union[str, Path]], default = None):
|
||||
keyset cache directory to use
|
||||
|
||||
Returns:
|
||||
Client:
|
||||
client loaded from the filesystem
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
shutil.unpack_archive(path, tmp_dir, "zip")
|
||||
with open(Path(tmp_dir) / "client.specs.json", "r", encoding="utf-8") as f:
|
||||
client_specs = ClientSpecs.unserialize(f.read())
|
||||
|
||||
return Client(client_specs, keyset_cache_directory)
|
||||
|
||||
def keygen(self, force: bool = False):
|
||||
"""
|
||||
Generate keys required for homomorphic evaluation.
|
||||
|
||||
Args:
|
||||
force (bool, default = False):
|
||||
whether to generate new keys even if keys are already generated
|
||||
"""
|
||||
|
||||
if self._keyset is None or force:
|
||||
self._keyset = ClientSupport.key_set(self.specs.client_parameters, self._keyset_cache)
|
||||
|
||||
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
|
||||
"""
|
||||
Prepare inputs to be run on the circuit.
|
||||
|
||||
Args:
|
||||
*args (Union[int, numpy.ndarray]):
|
||||
inputs to the circuit
|
||||
|
||||
Returns:
|
||||
PublicArguments:
|
||||
encrypted and plain arguments as well as public keys
|
||||
"""
|
||||
|
||||
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
|
||||
assert_that("inputs" in client_parameters_json)
|
||||
input_specs = client_parameters_json["inputs"]
|
||||
|
||||
if len(args) != len(input_specs):
|
||||
raise ValueError(f"Expected {len(input_specs)} inputs but got {len(args)}")
|
||||
|
||||
sanitized_args = {}
|
||||
for index, spec in enumerate(input_specs):
|
||||
arg = args[index]
|
||||
is_valid = isinstance(arg, (int, np.integer)) or (
|
||||
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
|
||||
)
|
||||
|
||||
width = spec["shape"]["width"]
|
||||
shape = tuple(spec["shape"]["dimensions"])
|
||||
is_encrypted = spec["encryption"] is not None
|
||||
|
||||
expected_dtype = UnsignedInteger(width)
|
||||
expected_value = Value(expected_dtype, shape, is_encrypted)
|
||||
if is_valid:
|
||||
expected_min = expected_dtype.min()
|
||||
expected_max = expected_dtype.max()
|
||||
|
||||
actual_min = arg if isinstance(arg, int) else arg.min()
|
||||
actual_max = arg if isinstance(arg, int) else arg.max()
|
||||
actual_shape = () if isinstance(arg, int) else arg.shape
|
||||
|
||||
is_valid = (
|
||||
actual_min >= expected_min
|
||||
and actual_max <= expected_max
|
||||
and actual_shape == expected_value.shape
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8)
|
||||
|
||||
if not is_valid:
|
||||
actual_value = Value.of(arg, is_encrypted=is_encrypted)
|
||||
raise ValueError(
|
||||
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
|
||||
)
|
||||
|
||||
self.keygen(force=False)
|
||||
return ClientSupport.encrypt_arguments(
|
||||
self.specs.client_parameters,
|
||||
self._keyset,
|
||||
[sanitized_args[i] for i in range(len(sanitized_args))],
|
||||
)
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
result: PublicResult,
|
||||
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
"""
|
||||
Decrypt result of homomorphic evaluaton.
|
||||
|
||||
Args:
|
||||
result (PublicResult):
|
||||
encrypted result of homomorphic evaluaton
|
||||
|
||||
Returns:
|
||||
Union[int, numpy.ndarray]:
|
||||
clear result of homomorphic evaluaton
|
||||
"""
|
||||
|
||||
results = ClientSupport.decrypt_result(self._keyset, result)
|
||||
if not isinstance(results, tuple):
|
||||
results = (results,)
|
||||
|
||||
sanitized_results: List[Union[int, np.ndarray]] = []
|
||||
|
||||
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
|
||||
assert_that("outputs" in client_parameters_json)
|
||||
output_specs = client_parameters_json["outputs"]
|
||||
|
||||
for index, spec in enumerate(output_specs):
|
||||
n = spec["shape"]["width"]
|
||||
expected_dtype = (
|
||||
SignedInteger(n) if self.specs.output_signs[index] else UnsignedInteger(n)
|
||||
)
|
||||
|
||||
result = results[index] % (2**n)
|
||||
if expected_dtype.is_signed:
|
||||
if isinstance(result, int):
|
||||
sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n)
|
||||
sanitized_results.append(sanititzed_result)
|
||||
else:
|
||||
result = result.astype(np.longlong) # to prevent overflows in numpy
|
||||
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n))
|
||||
sanitized_results.append(sanititzed_result.astype(np.int8))
|
||||
else:
|
||||
sanitized_results.append(
|
||||
result if isinstance(result, int) else result.astype(np.uint8)
|
||||
)
|
||||
|
||||
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
|
||||
@@ -375,10 +375,12 @@ class Compiler:
|
||||
|
||||
print()
|
||||
|
||||
circuit = Circuit.create(self.graph, mlir, self.configuration)
|
||||
circuit = Circuit(self.graph, mlir, self.configuration)
|
||||
if not self.configuration.virtual:
|
||||
assert circuit.client_parameters is not None
|
||||
self.artifacts.add_client_parameters(circuit.client_parameters.serialize())
|
||||
assert circuit.client.specs.client_parameters is not None
|
||||
self.artifacts.add_client_parameters(
|
||||
circuit.client.specs.client_parameters.serialize()
|
||||
)
|
||||
return circuit
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
|
||||
182
concrete/numpy/compilation/server.py
Normal file
182
concrete/numpy/compilation/server.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Declaration of `Server` class.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from concrete.compiler import (
|
||||
CompilationOptions,
|
||||
JITCompilationResult,
|
||||
JITLambda,
|
||||
JITSupport,
|
||||
LibraryCompilationResult,
|
||||
LibraryLambda,
|
||||
LibrarySupport,
|
||||
PublicArguments,
|
||||
PublicResult,
|
||||
)
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
from .configuration import Configuration
|
||||
from .specs import ClientSpecs
|
||||
|
||||
|
||||
class Server:
|
||||
"""
|
||||
Client class, which can be used to perform homomorphic computation.
|
||||
"""
|
||||
|
||||
client_specs: ClientSpecs
|
||||
|
||||
_output_dir: Optional[tempfile.TemporaryDirectory]
|
||||
_support: Union[JITSupport, LibrarySupport]
|
||||
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
|
||||
_server_lambda: Union[JITLambda, LibraryLambda]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_specs: ClientSpecs,
|
||||
output_dir: Optional[tempfile.TemporaryDirectory],
|
||||
support: Union[JITSupport, LibrarySupport],
|
||||
compilation_result: Union[JITCompilationResult, LibraryCompilationResult],
|
||||
server_lambda: Union[JITLambda, LibraryLambda],
|
||||
):
|
||||
self.client_specs = client_specs
|
||||
|
||||
self._output_dir = output_dir
|
||||
self._support = support
|
||||
self._compilation_result = compilation_result
|
||||
self._server_lambda = server_lambda
|
||||
|
||||
assert_that(
|
||||
support.load_client_parameters(compilation_result).serialize()
|
||||
== client_specs.client_parameters.serialize()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create(mlir: str, output_signs: List[bool], configuration: Configuration) -> "Server":
|
||||
"""
|
||||
Create a server using MLIR and output sign information.
|
||||
|
||||
Args:
|
||||
mlir (str):
|
||||
mlir to compile
|
||||
|
||||
output_signs (List[bool]):
|
||||
sign status of the outputs
|
||||
|
||||
configuration (Optional[Configuration], default = None):
|
||||
configuration to use
|
||||
"""
|
||||
|
||||
options = CompilationOptions.new("main")
|
||||
|
||||
options.set_loop_parallelize(configuration.loop_parallelize)
|
||||
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
|
||||
options.set_auto_parallelize(configuration.auto_parallelize)
|
||||
options.set_p_error(configuration.p_error)
|
||||
|
||||
if configuration.jit:
|
||||
|
||||
output_dir = None
|
||||
|
||||
support = JITSupport.new()
|
||||
compilation_result = support.compile(mlir, options)
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
else:
|
||||
|
||||
# pylint: disable=consider-using-with
|
||||
output_dir = tempfile.TemporaryDirectory()
|
||||
output_dir_path = Path(output_dir.name)
|
||||
# pylint: enable=consider-using-with
|
||||
|
||||
support = LibrarySupport.new(
|
||||
str(output_dir_path), generateCppHeader=False, generateStaticLib=False
|
||||
)
|
||||
compilation_result = support.compile(mlir, options)
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
client_specs = ClientSpecs(support.load_client_parameters(compilation_result), output_signs)
|
||||
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
|
||||
|
||||
def save(self, path: Union[str, Path]):
|
||||
"""
|
||||
Save the server into the given path in zip format.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]):
|
||||
path to save the server
|
||||
"""
|
||||
|
||||
if self._output_dir is None:
|
||||
raise RuntimeError("Just-in-Time compilation cannot be saved")
|
||||
|
||||
with open(Path(self._output_dir.name) / "client.specs.json", "w", encoding="utf-8") as f:
|
||||
f.write(self.client_specs.serialize())
|
||||
|
||||
path = str(path)
|
||||
if path.endswith(".zip"):
|
||||
path = path[: len(path) - 4]
|
||||
|
||||
shutil.make_archive(path, "zip", self._output_dir.name)
|
||||
|
||||
@staticmethod
|
||||
def load(path: Union[str, Path]) -> "Server":
|
||||
"""
|
||||
Load the server from the given path in zip format.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]):
|
||||
path to load the server from
|
||||
|
||||
Returns:
|
||||
Server:
|
||||
server loaded from the filesystem
|
||||
"""
|
||||
|
||||
# pylint: disable=consider-using-with
|
||||
output_dir = tempfile.TemporaryDirectory()
|
||||
output_dir_path = Path(output_dir.name)
|
||||
# pylint: enable=consider-using-with
|
||||
|
||||
shutil.unpack_archive(path, str(output_dir_path), "zip")
|
||||
|
||||
with open(output_dir_path / "client.specs.json", "r", encoding="utf-8") as f:
|
||||
client_specs = ClientSpecs.unserialize(f.read())
|
||||
|
||||
support = LibrarySupport.new(
|
||||
str(output_dir_path),
|
||||
generateCppHeader=False,
|
||||
generateStaticLib=False,
|
||||
)
|
||||
compilation_result = support.reload("main")
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
|
||||
|
||||
def run(self, args: PublicArguments) -> PublicResult:
|
||||
"""
|
||||
Evaluate using encrypted arguments.
|
||||
|
||||
Args:
|
||||
args (PublicArguments):
|
||||
encrypted arguments of the computation
|
||||
|
||||
Returns:
|
||||
PublicResult:
|
||||
encrypted result of the computation
|
||||
"""
|
||||
|
||||
return self._support.server_call(self._server_lambda, args)
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Cleanup the temporary library output directory.
|
||||
"""
|
||||
|
||||
if self._output_dir is not None:
|
||||
self._output_dir.cleanup()
|
||||
59
concrete/numpy/compilation/specs.py
Normal file
59
concrete/numpy/compilation/specs.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Declaration of `ClientSpecs` class.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from concrete.compiler import ClientParameters
|
||||
|
||||
|
||||
class ClientSpecs:
|
||||
"""
|
||||
ClientSpecs class, to create Client objects.
|
||||
"""
|
||||
|
||||
client_parameters: ClientParameters
|
||||
output_signs: List[bool]
|
||||
|
||||
def __init__(self, client_parameters: ClientParameters, output_signs: List[bool]):
|
||||
self.client_parameters = client_parameters
|
||||
self.output_signs = output_signs
|
||||
|
||||
def serialize(self) -> str:
|
||||
"""
|
||||
Serialize client specs into a string representation.
|
||||
|
||||
Returns:
|
||||
str:
|
||||
string representation of the client specs
|
||||
"""
|
||||
|
||||
client_parameters_json = json.loads(self.client_parameters.serialize())
|
||||
return json.dumps(
|
||||
{
|
||||
"client_parameters": client_parameters_json,
|
||||
"output_signs": self.output_signs,
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unserialize(serialized_client_specs: str) -> "ClientSpecs":
|
||||
"""
|
||||
Create client specs from its string representation.
|
||||
|
||||
Args:
|
||||
serialized_client_specs (str):
|
||||
client specs to unserialize
|
||||
|
||||
Returns:
|
||||
ClientSpecs:
|
||||
unserialized client specs
|
||||
"""
|
||||
|
||||
raw_specs = json.loads(serialized_client_specs)
|
||||
|
||||
client_parameters_bytes = json.dumps(raw_specs["client_parameters"]).encode("utf-8")
|
||||
client_parameters = ClientParameters.unserialize(client_parameters_bytes)
|
||||
|
||||
return ClientSpecs(client_parameters, raw_specs["output_signs"])
|
||||
@@ -7,8 +7,9 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from concrete.compiler import PublicArguments, PublicResult
|
||||
|
||||
from concrete.numpy import Circuit
|
||||
from concrete.numpy import Client, ClientSpecs, Configuration, Server
|
||||
from concrete.numpy.compilation import compiler
|
||||
|
||||
|
||||
@@ -98,7 +99,7 @@ def test_circuit_bad_run(helpers):
|
||||
circuit.encrypt_run_decrypt(-1, 11)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Expected argument 0 to be EncryptedScalar<uint4> but it's EncryptedScalar<int1>"
|
||||
"Expected argument 0 to be EncryptedScalar<uint6> but it's EncryptedScalar<int1>"
|
||||
)
|
||||
|
||||
# with negative argument 1
|
||||
@@ -108,7 +109,7 @@ def test_circuit_bad_run(helpers):
|
||||
circuit.encrypt_run_decrypt(1, -11)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Expected argument 1 to be EncryptedScalar<uint5> but it's EncryptedScalar<int5>"
|
||||
"Expected argument 1 to be EncryptedScalar<uint6> but it's EncryptedScalar<int5>"
|
||||
)
|
||||
|
||||
# with large argument 0
|
||||
@@ -118,7 +119,7 @@ def test_circuit_bad_run(helpers):
|
||||
circuit.encrypt_run_decrypt(100, 10)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Expected argument 0 to be EncryptedScalar<uint4> but it's EncryptedScalar<uint7>"
|
||||
"Expected argument 0 to be EncryptedScalar<uint6> but it's EncryptedScalar<uint7>"
|
||||
)
|
||||
|
||||
# with large argument 1
|
||||
@@ -128,7 +129,7 @@ def test_circuit_bad_run(helpers):
|
||||
circuit.encrypt_run_decrypt(1, 100)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Expected argument 1 to be EncryptedScalar<uint5> but it's EncryptedScalar<uint7>"
|
||||
"Expected argument 1 to be EncryptedScalar<uint6> but it's EncryptedScalar<uint7>"
|
||||
)
|
||||
|
||||
|
||||
@@ -167,9 +168,69 @@ def test_circuit_virtual_explicit_api(helpers):
|
||||
assert str(excinfo.value) == "Virtual circuits cannot use `decrypt` method"
|
||||
|
||||
|
||||
def test_circuit_bad_save(helpers):
|
||||
def test_client_server_api(helpers):
|
||||
"""
|
||||
Test `save` method of `Circuit` class with bad parameters.
|
||||
Test client/server API.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
circuit = function.compile(inputset, configuration.fork(jit=False))
|
||||
|
||||
# for coverage
|
||||
circuit.keygen()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_dir_path = Path(tmp_dir)
|
||||
|
||||
server_path = tmp_dir_path / "server.zip"
|
||||
circuit.server.save(server_path)
|
||||
|
||||
client_path = tmp_dir_path / "client.zip"
|
||||
circuit.client.save(client_path)
|
||||
|
||||
circuit.cleanup()
|
||||
|
||||
server = Server.load(server_path)
|
||||
|
||||
serialized_client_specs = server.client_specs.serialize()
|
||||
client_specs = ClientSpecs.unserialize(serialized_client_specs)
|
||||
|
||||
clients = [
|
||||
Client(client_specs, Configuration.insecure_key_cache_location()),
|
||||
Client.load(client_path, Configuration.insecure_key_cache_location()),
|
||||
]
|
||||
|
||||
for client in clients:
|
||||
raw_input = client.encrypt(4)
|
||||
serialized_input = raw_input.serialize()
|
||||
|
||||
unserialized_input = PublicArguments.unserialize(
|
||||
server.client_specs.client_parameters,
|
||||
serialized_input,
|
||||
)
|
||||
evaluation = server.run(unserialized_input)
|
||||
serialized_evaluation = evaluation.serialize()
|
||||
|
||||
unserialized_evaluation = PublicResult.unserialize(
|
||||
client.specs.client_parameters,
|
||||
serialized_evaluation,
|
||||
)
|
||||
output = client.decrypt(unserialized_evaluation)
|
||||
|
||||
assert output == 46
|
||||
|
||||
server.cleanup()
|
||||
|
||||
|
||||
def test_bad_server_save(helpers):
|
||||
"""
|
||||
Test `save` method of `Server` class with bad parameters.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
@@ -182,69 +243,6 @@ def test_circuit_bad_save(helpers):
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
circuit.save("circuit.zip")
|
||||
circuit.server.save("test.zip")
|
||||
|
||||
assert str(excinfo.value) == "JIT Circuits cannot be saved"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"virtual",
|
||||
[False, True],
|
||||
)
|
||||
def test_circuit_save_load(virtual, helpers):
|
||||
"""
|
||||
Test `save`, `load`, and `cleanup` methods of `Circuit` class.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration().fork(jit=False, virtual=virtual)
|
||||
|
||||
def save(base):
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
circuit.save(base / "circuit.zip")
|
||||
circuit.cleanup()
|
||||
|
||||
def load(base):
|
||||
circuit = Circuit.load(base / "circuit.zip")
|
||||
|
||||
helpers.check_str(
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint4>
|
||||
%1 = 42 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint6>
|
||||
return %2
|
||||
|
||||
""",
|
||||
str(circuit),
|
||||
)
|
||||
if virtual:
|
||||
helpers.check_str("Virtual circuits doesn't have MLIR.", circuit.mlir)
|
||||
else:
|
||||
helpers.check_str(
|
||||
"""
|
||||
|
||||
module {
|
||||
func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%c42_i7 = arith.constant 42 : i7
|
||||
%0 = "FHE.add_eint_int"(%arg0, %c42_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
|
||||
return %0 : !FHE.eint<6>
|
||||
}
|
||||
}
|
||||
|
||||
""",
|
||||
circuit.mlir,
|
||||
)
|
||||
helpers.check_execution(circuit, lambda x: x + 42, 4)
|
||||
|
||||
circuit.cleanup()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp)
|
||||
save(path)
|
||||
load(path)
|
||||
assert str(excinfo.value) == "Just-in-Time compilation cannot be saved"
|
||||
|
||||
Reference in New Issue
Block a user