diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index b7c125afe..6bcedbf64 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -6,12 +6,24 @@ from pathlib import Path from typing import List, Optional, Tuple, Union, cast import numpy as np -from concrete.compiler import CompilerEngine +from concrete.compiler import ( + ClientParameters, + ClientSupport, + CompilationOptions, + JITCompilationResult, + JITLambda, + JITSupport, + KeySet, + KeySetCache, + PublicArguments, + PublicResult, +) from ..dtypes import Integer from ..internal.utils import assert_that from ..representation import Graph from ..values import Value +from .configuration import CompilationConfiguration class Circuit: @@ -20,11 +32,41 @@ class Circuit: """ graph: Graph - engine: CompilerEngine + mlir: str - def __init__(self, graph: Graph, engine: CompilerEngine): + _jit_support: JITSupport + _compilation_result: JITCompilationResult + + _client_parameters: ClientParameters + + _keyset_cache: KeySetCache + _keyset: KeySet + + _server_lambda: JITLambda + + def __init__(self, graph: Graph, mlir: str, configuration: CompilationConfiguration): self.graph = graph - self.engine = engine + self.mlir = mlir + + options = CompilationOptions.new("main") + + options.set_loop_parallelize(configuration.loop_parallelize) + options.set_dataflow_parallelize(configuration.dataflow_parallelize) + options.set_auto_parallelize(configuration.auto_parallelize) + + self._jit_support = JITSupport.new() + self._compilation_result = self._jit_support.compile(mlir, options) + + self._client_parameters = self._jit_support.load_client_parameters(self._compilation_result) + + self._keyset_cache = None + if configuration.use_insecure_key_cache: + assert_that(configuration.enable_unsafe_features) + location = CompilationConfiguration.insecure_key_cache_location() + self._keyset_cache = KeySetCache.new(str(location)) + self._keyset = None + + self._server_lambda = self._jit_support.load_server_lambda(self._compilation_result) def __str__(self): return self.graph.format() @@ -36,7 +78,7 @@ class Circuit: save_to: Optional[Union[Path, str]] = None, ) -> Path: """ - Draw the `self.graph` and optionally save/show the drawing. + Draw `self.graph` and optionally save/show the drawing. note that this function requires the python `pygraphviz` package which itself requires the installation of `graphviz` packages @@ -60,27 +102,35 @@ class Circuit: return self.graph.draw(show, horizontal, save_to) - def encrypt_run_decrypt( - self, - *args: Union[int, np.ndarray], - ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: + def keygen(self, force: bool = False): """ - Encrypt inputs, run the circuit, and decrypt the outputs in one go. + Generate keys required for homomorphic evaluation. + + Args: + force (bool, default = False): + whether to generate new keys even if keys are already generated + """ + + if self._keyset is None or force: + self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache) + + def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments: + """ + Prepare inputs to be run on the circuit. Args: *args (Union[int, numpy.ndarray]): - inputs to the engine + inputs to the circuit Returns: - Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: - result of the homomorphic evaluation + PublicArguments: + encrypted and plain arguments as well as public keys """ if len(args) != len(self.graph.input_nodes): raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}") sanitized_args = {} - for index, node in self.graph.input_nodes.items(): arg = args[index] is_valid = isinstance(arg, (int, np.integer)) or ( @@ -116,7 +166,45 @@ class Circuit: f"Expected argument {index} to be {expected_value} but it's {actual_value}" ) - results = self.engine.run(*[sanitized_args[i] for i in range(len(sanitized_args))]) + self.keygen(force=False) + return ClientSupport.encrypt_arguments( + self._client_parameters, + self._keyset, + [sanitized_args[i] for i in range(len(sanitized_args))], + ) + + def run(self, args: PublicArguments) -> PublicResult: + """ + Evaluate circuit using encrypted arguments. + + Args: + args (PublicArguments): + arguments to the circuit (can be obtained with `encrypt` method of `Circuit`) + + Returns: + PublicResult: + encrypted result of homomorphic evaluaton + """ + + return self._jit_support.server_call(self._server_lambda, args) + + def decrypt( + self, + result: PublicResult, + ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: + """ + Decrypt result of homomorphic evaluaton. + + Args: + result (PublicResult): + encrypted result of homomorphic evaluaton + + Returns: + Union[int, numpy.ndarray]: + clear result of homomorphic evaluaton + """ + + results = ClientSupport.decrypt_result(self._keyset, result) if not isinstance(results, tuple): results = (results,) @@ -144,3 +232,21 @@ class Circuit: ) return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results) + + def encrypt_run_decrypt( + self, + *args: Union[int, np.ndarray], + ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: + """ + Encrypt inputs, run the circuit, and decrypt the outputs in one go. + + Args: + *args (Union[int, numpy.ndarray]): + inputs to the circuit + + Returns: + Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: + clear result of homomorphic evaluation + """ + + return self.decrypt(self.run(self.encrypt(*args))) diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index cd1081156..a1ba47287 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -9,7 +9,6 @@ from enum import Enum, unique from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np -from concrete.compiler import CompilerEngine from ..mlir import GraphConverter from ..representation import Graph @@ -322,17 +321,7 @@ class Compiler: print() - engine = CompilerEngine() - - if self.configuration.use_insecure_key_cache: - assert self.configuration.enable_unsafe_features - location = CompilationConfiguration.insecure_key_cache_location() - engine.compile_fhe(mlir, unsecure_key_set_cache_path=location) - else: - # this branch is not covered because all tests use key cache to speed up tests - engine.compile_fhe(mlir) # pragma: no cover - - return Circuit(self.graph, engine) + return Circuit(self.graph, mlir, self.configuration) except Exception: # pragma: no cover diff --git a/concrete/numpy/compilation/configuration.py b/concrete/numpy/compilation/configuration.py index 4c01a4223..3de01d559 100644 --- a/concrete/numpy/compilation/configuration.py +++ b/concrete/numpy/compilation/configuration.py @@ -15,16 +15,25 @@ class CompilationConfiguration: dump_artifacts_on_unexpected_failures: bool enable_unsafe_features: bool use_insecure_key_cache: bool + loop_parallelize: bool + dataflow_parallelize: bool + auto_parallelize: bool def __init__( self, dump_artifacts_on_unexpected_failures: bool = True, enable_unsafe_features: bool = False, use_insecure_key_cache: bool = False, + loop_parallelize: bool = True, + dataflow_parallelize: bool = False, + auto_parallelize: bool = False, ): self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures self.enable_unsafe_features = enable_unsafe_features self.use_insecure_key_cache = use_insecure_key_cache + self.loop_parallelize = loop_parallelize + self.dataflow_parallelize = dataflow_parallelize + self.auto_parallelize = auto_parallelize if not enable_unsafe_features and use_insecure_key_cache: raise RuntimeError("Insecure key cache cannot be used without enabling unsafe features") diff --git a/docs/user/basics/compiling_and_executing.md b/docs/user/basics/compiling_and_executing.md index 7bc0ba4b5..2b428a9df 100644 --- a/docs/user/basics/compiling_and_executing.md +++ b/docs/user/basics/compiling_and_executing.md @@ -83,11 +83,20 @@ Be careful about the inputs, though. If you were to run with values outside the range of the inputset, the result might not be correct. ``` -Today, we cannot simulate a client / server API in python, but it is for very soon. Then, we will have: - - a `keygen` API, which is used to generate both public and private keys - - an `encrypt` API, which happens on the user's device, and is using private keys - - a `run_inference` API, which happens on the untrusted server and only uses public material - - a `encrypt` API, which happens on the user's device to get final clear result, and is using private keys +While `.encrypt_run_decrypt(...)` is a good start for prototyping examples, more advanced usages require control over the different steps that are happening behind the scene, mainly key generation, encryption, execution, and decryption. The different steps can of course be called separately as in the example below: + + +```python +# generate keys required for encrypted computation +circuit.keygen() +# this will encrypt arguments that require encryption and pack all arguments +# as well as public materials (public keys) +public_args = circuit.encrypt(3, 4) +# this will run the encrypted computation using public materials and inputs provided +encrypted_result = circuit.run(public_args) +# the execution returns the encrypted result which can later be decrypted +decrypted_result = circuit.decrypt(encrypted_result) +``` ## Further reading diff --git a/tests/conftest.py b/tests/conftest.py index d399dabbb..68b9153b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -104,6 +104,9 @@ class Helpers: dump_artifacts_on_unexpected_failures=False, enable_unsafe_features=True, use_insecure_key_cache=True, + loop_parallelize=True, + dataflow_parallelize=False, + auto_parallelize=False, ) @staticmethod