mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: introduce explicit encrypt/decrypt/run api
This commit is contained in:
@@ -6,12 +6,24 @@ from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine
|
||||
from concrete.compiler import (
|
||||
ClientParameters,
|
||||
ClientSupport,
|
||||
CompilationOptions,
|
||||
JITCompilationResult,
|
||||
JITLambda,
|
||||
JITSupport,
|
||||
KeySet,
|
||||
KeySetCache,
|
||||
PublicArguments,
|
||||
PublicResult,
|
||||
)
|
||||
|
||||
from ..dtypes import Integer
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph
|
||||
from ..values import Value
|
||||
from .configuration import CompilationConfiguration
|
||||
|
||||
|
||||
class Circuit:
|
||||
@@ -20,11 +32,41 @@ class Circuit:
|
||||
"""
|
||||
|
||||
graph: Graph
|
||||
engine: CompilerEngine
|
||||
mlir: str
|
||||
|
||||
def __init__(self, graph: Graph, engine: CompilerEngine):
|
||||
_jit_support: JITSupport
|
||||
_compilation_result: JITCompilationResult
|
||||
|
||||
_client_parameters: ClientParameters
|
||||
|
||||
_keyset_cache: KeySetCache
|
||||
_keyset: KeySet
|
||||
|
||||
_server_lambda: JITLambda
|
||||
|
||||
def __init__(self, graph: Graph, mlir: str, configuration: CompilationConfiguration):
|
||||
self.graph = graph
|
||||
self.engine = engine
|
||||
self.mlir = mlir
|
||||
|
||||
options = CompilationOptions.new("main")
|
||||
|
||||
options.set_loop_parallelize(configuration.loop_parallelize)
|
||||
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
|
||||
options.set_auto_parallelize(configuration.auto_parallelize)
|
||||
|
||||
self._jit_support = JITSupport.new()
|
||||
self._compilation_result = self._jit_support.compile(mlir, options)
|
||||
|
||||
self._client_parameters = self._jit_support.load_client_parameters(self._compilation_result)
|
||||
|
||||
self._keyset_cache = None
|
||||
if configuration.use_insecure_key_cache:
|
||||
assert_that(configuration.enable_unsafe_features)
|
||||
location = CompilationConfiguration.insecure_key_cache_location()
|
||||
self._keyset_cache = KeySetCache.new(str(location))
|
||||
self._keyset = None
|
||||
|
||||
self._server_lambda = self._jit_support.load_server_lambda(self._compilation_result)
|
||||
|
||||
def __str__(self):
|
||||
return self.graph.format()
|
||||
@@ -36,7 +78,7 @@ class Circuit:
|
||||
save_to: Optional[Union[Path, str]] = None,
|
||||
) -> Path:
|
||||
"""
|
||||
Draw the `self.graph` and optionally save/show the drawing.
|
||||
Draw `self.graph` and optionally save/show the drawing.
|
||||
|
||||
note that this function requires the python `pygraphviz` package
|
||||
which itself requires the installation of `graphviz` packages
|
||||
@@ -60,27 +102,35 @@ class Circuit:
|
||||
|
||||
return self.graph.draw(show, horizontal, save_to)
|
||||
|
||||
def encrypt_run_decrypt(
|
||||
self,
|
||||
*args: Union[int, np.ndarray],
|
||||
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
def keygen(self, force: bool = False):
|
||||
"""
|
||||
Encrypt inputs, run the circuit, and decrypt the outputs in one go.
|
||||
Generate keys required for homomorphic evaluation.
|
||||
|
||||
Args:
|
||||
force (bool, default = False):
|
||||
whether to generate new keys even if keys are already generated
|
||||
"""
|
||||
|
||||
if self._keyset is None or force:
|
||||
self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache)
|
||||
|
||||
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
|
||||
"""
|
||||
Prepare inputs to be run on the circuit.
|
||||
|
||||
Args:
|
||||
*args (Union[int, numpy.ndarray]):
|
||||
inputs to the engine
|
||||
inputs to the circuit
|
||||
|
||||
Returns:
|
||||
Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
result of the homomorphic evaluation
|
||||
PublicArguments:
|
||||
encrypted and plain arguments as well as public keys
|
||||
"""
|
||||
|
||||
if len(args) != len(self.graph.input_nodes):
|
||||
raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}")
|
||||
|
||||
sanitized_args = {}
|
||||
|
||||
for index, node in self.graph.input_nodes.items():
|
||||
arg = args[index]
|
||||
is_valid = isinstance(arg, (int, np.integer)) or (
|
||||
@@ -116,7 +166,45 @@ class Circuit:
|
||||
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
|
||||
)
|
||||
|
||||
results = self.engine.run(*[sanitized_args[i] for i in range(len(sanitized_args))])
|
||||
self.keygen(force=False)
|
||||
return ClientSupport.encrypt_arguments(
|
||||
self._client_parameters,
|
||||
self._keyset,
|
||||
[sanitized_args[i] for i in range(len(sanitized_args))],
|
||||
)
|
||||
|
||||
def run(self, args: PublicArguments) -> PublicResult:
|
||||
"""
|
||||
Evaluate circuit using encrypted arguments.
|
||||
|
||||
Args:
|
||||
args (PublicArguments):
|
||||
arguments to the circuit (can be obtained with `encrypt` method of `Circuit`)
|
||||
|
||||
Returns:
|
||||
PublicResult:
|
||||
encrypted result of homomorphic evaluaton
|
||||
"""
|
||||
|
||||
return self._jit_support.server_call(self._server_lambda, args)
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
result: PublicResult,
|
||||
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
"""
|
||||
Decrypt result of homomorphic evaluaton.
|
||||
|
||||
Args:
|
||||
result (PublicResult):
|
||||
encrypted result of homomorphic evaluaton
|
||||
|
||||
Returns:
|
||||
Union[int, numpy.ndarray]:
|
||||
clear result of homomorphic evaluaton
|
||||
"""
|
||||
|
||||
results = ClientSupport.decrypt_result(self._keyset, result)
|
||||
if not isinstance(results, tuple):
|
||||
results = (results,)
|
||||
|
||||
@@ -144,3 +232,21 @@ class Circuit:
|
||||
)
|
||||
|
||||
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
|
||||
|
||||
def encrypt_run_decrypt(
|
||||
self,
|
||||
*args: Union[int, np.ndarray],
|
||||
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
"""
|
||||
Encrypt inputs, run the circuit, and decrypt the outputs in one go.
|
||||
|
||||
Args:
|
||||
*args (Union[int, numpy.ndarray]):
|
||||
inputs to the circuit
|
||||
|
||||
Returns:
|
||||
Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
clear result of homomorphic evaluation
|
||||
"""
|
||||
|
||||
return self.decrypt(self.run(self.encrypt(*args)))
|
||||
|
||||
@@ -9,7 +9,6 @@ from enum import Enum, unique
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine
|
||||
|
||||
from ..mlir import GraphConverter
|
||||
from ..representation import Graph
|
||||
@@ -322,17 +321,7 @@ class Compiler:
|
||||
|
||||
print()
|
||||
|
||||
engine = CompilerEngine()
|
||||
|
||||
if self.configuration.use_insecure_key_cache:
|
||||
assert self.configuration.enable_unsafe_features
|
||||
location = CompilationConfiguration.insecure_key_cache_location()
|
||||
engine.compile_fhe(mlir, unsecure_key_set_cache_path=location)
|
||||
else:
|
||||
# this branch is not covered because all tests use key cache to speed up tests
|
||||
engine.compile_fhe(mlir) # pragma: no cover
|
||||
|
||||
return Circuit(self.graph, engine)
|
||||
return Circuit(self.graph, mlir, self.configuration)
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
|
||||
|
||||
@@ -15,16 +15,25 @@ class CompilationConfiguration:
|
||||
dump_artifacts_on_unexpected_failures: bool
|
||||
enable_unsafe_features: bool
|
||||
use_insecure_key_cache: bool
|
||||
loop_parallelize: bool
|
||||
dataflow_parallelize: bool
|
||||
auto_parallelize: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dump_artifacts_on_unexpected_failures: bool = True,
|
||||
enable_unsafe_features: bool = False,
|
||||
use_insecure_key_cache: bool = False,
|
||||
loop_parallelize: bool = True,
|
||||
dataflow_parallelize: bool = False,
|
||||
auto_parallelize: bool = False,
|
||||
):
|
||||
self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
|
||||
self.enable_unsafe_features = enable_unsafe_features
|
||||
self.use_insecure_key_cache = use_insecure_key_cache
|
||||
self.loop_parallelize = loop_parallelize
|
||||
self.dataflow_parallelize = dataflow_parallelize
|
||||
self.auto_parallelize = auto_parallelize
|
||||
|
||||
if not enable_unsafe_features and use_insecure_key_cache:
|
||||
raise RuntimeError("Insecure key cache cannot be used without enabling unsafe features")
|
||||
|
||||
@@ -83,11 +83,20 @@ Be careful about the inputs, though.
|
||||
If you were to run with values outside the range of the inputset, the result might not be correct.
|
||||
```
|
||||
|
||||
Today, we cannot simulate a client / server API in python, but it is for very soon. Then, we will have:
|
||||
- a `keygen` API, which is used to generate both public and private keys
|
||||
- an `encrypt` API, which happens on the user's device, and is using private keys
|
||||
- a `run_inference` API, which happens on the untrusted server and only uses public material
|
||||
- a `encrypt` API, which happens on the user's device to get final clear result, and is using private keys
|
||||
While `.encrypt_run_decrypt(...)` is a good start for prototyping examples, more advanced usages require control over the different steps that are happening behind the scene, mainly key generation, encryption, execution, and decryption. The different steps can of course be called separately as in the example below:
|
||||
|
||||
<!--pytest-codeblocks:cont-->
|
||||
```python
|
||||
# generate keys required for encrypted computation
|
||||
circuit.keygen()
|
||||
# this will encrypt arguments that require encryption and pack all arguments
|
||||
# as well as public materials (public keys)
|
||||
public_args = circuit.encrypt(3, 4)
|
||||
# this will run the encrypted computation using public materials and inputs provided
|
||||
encrypted_result = circuit.run(public_args)
|
||||
# the execution returns the encrypted result which can later be decrypted
|
||||
decrypted_result = circuit.decrypt(encrypted_result)
|
||||
```
|
||||
|
||||
## Further reading
|
||||
|
||||
|
||||
@@ -104,6 +104,9 @@ class Helpers:
|
||||
dump_artifacts_on_unexpected_failures=False,
|
||||
enable_unsafe_features=True,
|
||||
use_insecure_key_cache=True,
|
||||
loop_parallelize=True,
|
||||
dataflow_parallelize=False,
|
||||
auto_parallelize=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user