feat: introduce explicit encrypt/decrypt/run api

This commit is contained in:
Umut
2022-04-06 17:22:33 +02:00
parent 38ccacca69
commit ce7646f102
5 changed files with 148 additions and 32 deletions

View File

@@ -6,12 +6,24 @@ from pathlib import Path
from typing import List, Optional, Tuple, Union, cast
import numpy as np
from concrete.compiler import CompilerEngine
from concrete.compiler import (
ClientParameters,
ClientSupport,
CompilationOptions,
JITCompilationResult,
JITLambda,
JITSupport,
KeySet,
KeySetCache,
PublicArguments,
PublicResult,
)
from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Graph
from ..values import Value
from .configuration import CompilationConfiguration
class Circuit:
@@ -20,11 +32,41 @@ class Circuit:
"""
graph: Graph
engine: CompilerEngine
mlir: str
def __init__(self, graph: Graph, engine: CompilerEngine):
_jit_support: JITSupport
_compilation_result: JITCompilationResult
_client_parameters: ClientParameters
_keyset_cache: KeySetCache
_keyset: KeySet
_server_lambda: JITLambda
def __init__(self, graph: Graph, mlir: str, configuration: CompilationConfiguration):
self.graph = graph
self.engine = engine
self.mlir = mlir
options = CompilationOptions.new("main")
options.set_loop_parallelize(configuration.loop_parallelize)
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
options.set_auto_parallelize(configuration.auto_parallelize)
self._jit_support = JITSupport.new()
self._compilation_result = self._jit_support.compile(mlir, options)
self._client_parameters = self._jit_support.load_client_parameters(self._compilation_result)
self._keyset_cache = None
if configuration.use_insecure_key_cache:
assert_that(configuration.enable_unsafe_features)
location = CompilationConfiguration.insecure_key_cache_location()
self._keyset_cache = KeySetCache.new(str(location))
self._keyset = None
self._server_lambda = self._jit_support.load_server_lambda(self._compilation_result)
def __str__(self):
return self.graph.format()
@@ -36,7 +78,7 @@ class Circuit:
save_to: Optional[Union[Path, str]] = None,
) -> Path:
"""
Draw the `self.graph` and optionally save/show the drawing.
Draw `self.graph` and optionally save/show the drawing.
note that this function requires the python `pygraphviz` package
which itself requires the installation of `graphviz` packages
@@ -60,27 +102,35 @@ class Circuit:
return self.graph.draw(show, horizontal, save_to)
def encrypt_run_decrypt(
self,
*args: Union[int, np.ndarray],
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
def keygen(self, force: bool = False):
"""
Encrypt inputs, run the circuit, and decrypt the outputs in one go.
Generate keys required for homomorphic evaluation.
Args:
force (bool, default = False):
whether to generate new keys even if keys are already generated
"""
if self._keyset is None or force:
self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
"""
Prepare inputs to be run on the circuit.
Args:
*args (Union[int, numpy.ndarray]):
inputs to the engine
inputs to the circuit
Returns:
Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
result of the homomorphic evaluation
PublicArguments:
encrypted and plain arguments as well as public keys
"""
if len(args) != len(self.graph.input_nodes):
raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}")
sanitized_args = {}
for index, node in self.graph.input_nodes.items():
arg = args[index]
is_valid = isinstance(arg, (int, np.integer)) or (
@@ -116,7 +166,45 @@ class Circuit:
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
)
results = self.engine.run(*[sanitized_args[i] for i in range(len(sanitized_args))])
self.keygen(force=False)
return ClientSupport.encrypt_arguments(
self._client_parameters,
self._keyset,
[sanitized_args[i] for i in range(len(sanitized_args))],
)
def run(self, args: PublicArguments) -> PublicResult:
"""
Evaluate circuit using encrypted arguments.
Args:
args (PublicArguments):
arguments to the circuit (can be obtained with `encrypt` method of `Circuit`)
Returns:
PublicResult:
encrypted result of homomorphic evaluaton
"""
return self._jit_support.server_call(self._server_lambda, args)
def decrypt(
self,
result: PublicResult,
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
"""
Decrypt result of homomorphic evaluaton.
Args:
result (PublicResult):
encrypted result of homomorphic evaluaton
Returns:
Union[int, numpy.ndarray]:
clear result of homomorphic evaluaton
"""
results = ClientSupport.decrypt_result(self._keyset, result)
if not isinstance(results, tuple):
results = (results,)
@@ -144,3 +232,21 @@ class Circuit:
)
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
def encrypt_run_decrypt(
self,
*args: Union[int, np.ndarray],
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
"""
Encrypt inputs, run the circuit, and decrypt the outputs in one go.
Args:
*args (Union[int, numpy.ndarray]):
inputs to the circuit
Returns:
Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
clear result of homomorphic evaluation
"""
return self.decrypt(self.run(self.encrypt(*args)))

View File

@@ -9,7 +9,6 @@ from enum import Enum, unique
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import CompilerEngine
from ..mlir import GraphConverter
from ..representation import Graph
@@ -322,17 +321,7 @@ class Compiler:
print()
engine = CompilerEngine()
if self.configuration.use_insecure_key_cache:
assert self.configuration.enable_unsafe_features
location = CompilationConfiguration.insecure_key_cache_location()
engine.compile_fhe(mlir, unsecure_key_set_cache_path=location)
else:
# this branch is not covered because all tests use key cache to speed up tests
engine.compile_fhe(mlir) # pragma: no cover
return Circuit(self.graph, engine)
return Circuit(self.graph, mlir, self.configuration)
except Exception: # pragma: no cover

View File

@@ -15,16 +15,25 @@ class CompilationConfiguration:
dump_artifacts_on_unexpected_failures: bool
enable_unsafe_features: bool
use_insecure_key_cache: bool
loop_parallelize: bool
dataflow_parallelize: bool
auto_parallelize: bool
def __init__(
self,
dump_artifacts_on_unexpected_failures: bool = True,
enable_unsafe_features: bool = False,
use_insecure_key_cache: bool = False,
loop_parallelize: bool = True,
dataflow_parallelize: bool = False,
auto_parallelize: bool = False,
):
self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
self.enable_unsafe_features = enable_unsafe_features
self.use_insecure_key_cache = use_insecure_key_cache
self.loop_parallelize = loop_parallelize
self.dataflow_parallelize = dataflow_parallelize
self.auto_parallelize = auto_parallelize
if not enable_unsafe_features and use_insecure_key_cache:
raise RuntimeError("Insecure key cache cannot be used without enabling unsafe features")

View File

@@ -83,11 +83,20 @@ Be careful about the inputs, though.
If you were to run with values outside the range of the inputset, the result might not be correct.
```
Today, we cannot simulate a client / server API in python, but it is for very soon. Then, we will have:
- a `keygen` API, which is used to generate both public and private keys
- an `encrypt` API, which happens on the user's device, and is using private keys
- a `run_inference` API, which happens on the untrusted server and only uses public material
- a `encrypt` API, which happens on the user's device to get final clear result, and is using private keys
While `.encrypt_run_decrypt(...)` is a good start for prototyping examples, more advanced usages require control over the different steps that are happening behind the scene, mainly key generation, encryption, execution, and decryption. The different steps can of course be called separately as in the example below:
<!--pytest-codeblocks:cont-->
```python
# generate keys required for encrypted computation
circuit.keygen()
# this will encrypt arguments that require encryption and pack all arguments
# as well as public materials (public keys)
public_args = circuit.encrypt(3, 4)
# this will run the encrypted computation using public materials and inputs provided
encrypted_result = circuit.run(public_args)
# the execution returns the encrypted result which can later be decrypted
decrypted_result = circuit.decrypt(encrypted_result)
```
## Further reading

View File

@@ -104,6 +104,9 @@ class Helpers:
dump_artifacts_on_unexpected_failures=False,
enable_unsafe_features=True,
use_insecure_key_cache=True,
loop_parallelize=True,
dataflow_parallelize=False,
auto_parallelize=False,
)
@staticmethod